/*
 * Decompiled with CFR 0.152.
 */
package org.evrete.spi.minimal;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.StringJoiner;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.evrete.api.ExpressionResolver;
import org.evrete.api.FieldReference;
import org.evrete.api.Imports;
import org.evrete.api.IntToValue;
import org.evrete.api.JavaSourceCompiler;
import org.evrete.api.LiteralEvaluator;
import org.evrete.api.LiteralExpression;
import org.evrete.api.NamedType;
import org.evrete.api.RhsContext;
import org.evrete.api.Rule;
import org.evrete.api.RuleCompiledSources;
import org.evrete.api.RuleLiteralData;
import org.evrete.api.RuntimeContext;
import org.evrete.api.spi.LiteralSourceCompiler;
import org.evrete.spi.minimal.AbstractLiteralRhs;
import org.evrete.spi.minimal.BaseRuleClass;
import org.evrete.spi.minimal.ConditionStringTerm;
import org.evrete.spi.minimal.LeastImportantServiceProvider;
import org.evrete.spi.minimal.StringLiteralEncoder;
import org.evrete.util.CompilationException;

/**
 * Default {@link LiteralSourceCompiler}: for every rule it generates one Java class holding
 * the rule's literal conditions (as {@link MethodHandle}-backed static methods) and, when
 * present, its literal RHS (as a nested action class). It then compiles all generated classes
 * with the context's {@link JavaSourceCompiler}.
 *
 * <p>Generated classes live in {@link #CLASS_PACKAGE}, are named {@code Rule<N>} with a
 * JVM-wide counter suffix, and extend the class named by the rule property
 * {@code evrete.impl.rule-base-class} (default: {@link BaseRuleClass}).
 */
public class DefaultLiteralSourceCompiler
extends LeastImportantServiceProvider
implements LiteralSourceCompiler {
    /** One indentation level in generated sources. */
    private static final String TAB = "  ";
    /** Simple name of the nested RHS class emitted into a generated rule class. */
    private static final String RHS_CLASS_NAME = "Rhs";
    /** Name of the static field holding the RHS instance in a generated rule class. */
    private static final String RHS_INSTANCE_VAR = "ACTION";
    /** Configuration property forcing ("true") or disabling ("false") whitespace stripping in conditions. */
    private static final String STRIP_WHITESPACES_PROPERTY = "evrete.spi.compiler.lhs-strip-whitespaces";
    /** Rule property naming the superclass of the generated rule class. */
    private static final String RULE_BASE_CLASS_PROPERTY = "evrete.impl.rule-base-class";
    /** Supplies unique numeric suffixes for generated class names across all compilations. */
    private static final AtomicInteger classCounter = new AtomicInteger(0);
    static final String CLASS_PACKAGE = DefaultLiteralSourceCompiler.class.getPackage().getName() + ".compiled";

    /**
     * Compiles the literal conditions and RHS of the given rules.
     *
     * <p>If the {@code evrete.spi.compiler.lhs-strip-whitespaces} property is set, its boolean
     * value decides whether whitespace is stripped from conditions. If it is absent, compilation
     * with stripping is attempted first and, should it fail, retried without stripping.
     *
     * @param context runtime context supplying configuration, imports, expression resolver and compiler
     * @param sources rule sources to compile
     * @return one compiled entry per compiled rule class; empty when {@code sources} is empty
     * @throws CompilationException if compilation fails; when both fallback attempts fail, the
     *                              failure of the first (whitespace-stripping) attempt is attached
     *                              as a suppressed exception
     */
    @Override
    public <S extends RuleLiteralData<R>, R extends Rule> Collection<RuleCompiledSources<S, R>> compile(RuntimeContext<?> context, Collection<S> sources) throws CompilationException {
        if (sources.isEmpty()) {
            return Collections.emptyList();
        }
        String stripFlag = context.getConfiguration().getProperty(STRIP_WHITESPACES_PROPERTY);
        if (stripFlag != null) {
            return this.compile(context, sources, Boolean.parseBoolean(stripFlag));
        }
        try {
            return this.compile(context, sources, true);
        } catch (CompilationException strippedFailure) {
            try {
                return this.compile(context, sources, false);
            } catch (CompilationException plainFailure) {
                // Keep the diagnostics of the first attempt instead of silently discarding them.
                plainFailure.addSuppressed(strippedFailure);
                throw plainFailure;
            }
        }
    }

    /**
     * Generates one Java class per rule source and compiles them in a single compiler call.
     *
     * @param stripWhitespaces whether condition literals are encoded with whitespace stripping
     */
    private <S extends RuleLiteralData<R>, R extends Rule> Collection<RuleCompiledSources<S, R>> compile(RuntimeContext<?> context, Collection<S> sources, boolean stripWhitespaces) throws CompilationException {
        JavaSourceCompiler compiler = context.getSourceCompiler();
        List<RuleSource<S, R>> javaSources = new ArrayList<>(sources.size());
        for (S source : sources) {
            javaSources.add(new RuleSource<>(source, context, stripWhitespaces));
        }
        return compiler.compile(javaSources).stream()
                .map(compiled -> new RuleCompiledSourcesImpl<>(compiled.getCompiledClass(), compiled.getSource()))
                .collect(Collectors.toList());
    }

    /** Appends {@code depth} indentation levels to {@code target} and returns it for chaining. */
    private static StringBuilder indent(StringBuilder target, int depth) {
        for (int i = 0; i < depth; i++) {
            target.append(TAB);
        }
        return target;
    }

    /**
     * Compiled artifacts of one rule: an evaluator per literal condition, in the order returned
     * by {@code RuleLiteralData#conditions()}, plus the RHS action (or {@code null} if the rule
     * has no literal RHS).
     */
    private static final class RuleCompiledSourcesImpl<S extends RuleLiteralData<R>, R extends Rule>
    implements RuleCompiledSources<S, R> {
        private final RuleSource<S, R> source;
        private final Collection<LiteralEvaluator> conditions;
        private final Consumer<RhsContext> rhs;

        /**
         * @param ruleClass the compiled class generated from {@code source}
         * @param source    the generator that produced {@code ruleClass}
         * @throws IllegalStateException if a condition or the RHS cannot be found in {@code ruleClass}
         */
        public RuleCompiledSourcesImpl(Class<?> ruleClass, RuleSource<S, R> source) {
            this.source = source;
            // Keyed by identity: each ConditionSource keeps the exact String instance it was built from.
            IdentityHashMap<String, ConditionSource> compiledConditions = new IdentityHashMap<>();
            for (ConditionSource conditionSource : source.conditionSources) {
                compiledConditions.put(conditionSource.source, conditionSource);
            }
            Rule rule = source.delegate.getRule();
            Collection<String> originalConditions = source.delegate.conditions();
            this.conditions = new ArrayList<>(originalConditions.size());
            for (String condition : originalConditions) {
                ConditionSource compiled = compiledConditions.get(condition);
                if (compiled == null) {
                    throw new IllegalStateException("Condition not found or not compiled");
                }
                this.conditions.add(new LiteralEvaluatorImpl(rule, compiled, ruleClass));
            }
            this.rhs = source.delegate.rhs() == null ? null : fromClass(ruleClass);
        }

        /** Reads the static RHS instance field from a compiled rule class. */
        @SuppressWarnings("unchecked")
        private static Consumer<RhsContext> fromClass(Class<?> ruleClass) {
            try {
                return (Consumer<RhsContext>) ruleClass.getDeclaredField(RHS_INSTANCE_VAR).get(null);
            } catch (IllegalAccessException | NoSuchFieldException e) {
                throw new IllegalStateException("RHS source provided but not compiled", e);
            }
        }

        @Override
        public S getSources() {
            return this.source.delegate;
        }

        @Override
        public Collection<LiteralEvaluator> conditions() {
            return this.conditions;
        }

        @Override
        public Consumer<RhsContext> rhs() {
            return this.rhs;
        }
    }

    /**
     * Java source generator for one rule. The generated class contains, in order: the package
     * and import statements, an optional static RHS instance field, one {@code MethodHandle}
     * field per condition initialized in a static block, one public/private method pair per
     * condition, and an optional nested RHS class.
     */
    static class RuleSource<S extends RuleLiteralData<R>, R extends Rule>
    implements JavaSourceCompiler.ClassSource {
        private final String className;
        private final String classSimpleName;
        private final S delegate;
        private final Imports imports;
        private final RhsSource rhsSource;
        private final String javaSource;
        private final Collection<ConditionSource> conditionSources;

        RuleSource(S delegate, RuntimeContext<?> context, boolean stripWhitespaces) {
            this.delegate = delegate;
            this.imports = context.getImports();
            this.classSimpleName = "Rule" + classCounter.incrementAndGet();
            this.className = CLASS_PACKAGE + "." + this.classSimpleName;
            Rule rule = delegate.getRule();
            Collection<String> conditions = delegate.conditions();
            List<ConditionSource> conditionList = new ArrayList<>(conditions.size());
            int conditionIndex = 0;
            for (String condition : conditions) {
                conditionIndex++;
                conditionList.add(new ConditionSource(rule, "condition" + conditionIndex, this.classSimpleName, condition, context, stripWhitespaces));
            }
            this.conditionSources = conditionList;
            String rhs = delegate.rhs();
            this.rhsSource = rhs == null ? null : new RhsSource(rule, rhs);
            this.javaSource = this.buildSource();
        }

        /** Assembles the complete Java source of the generated rule class. */
        private String buildSource() {
            StringBuilder sb = new StringBuilder(4096);
            this.appendHeader(sb);
            if (this.rhsSource != null) {
                this.rhsSource.appendClassVar(sb);
            }
            for (ConditionSource condition : this.conditionSources) {
                indent(sb, 1);
                condition.appendDeclaration(sb);
            }
            if (!this.conditionSources.isEmpty()) {
                this.appendStaticInitializer(sb);
            }
            for (ConditionSource condition : this.conditionSources) {
                condition.appendHandleMethod(sb);
                condition.appendInnerMethod(sb);
                sb.append("\n");
            }
            if (this.rhsSource != null) {
                this.rhsSource.appendClassBody(sb);
            }
            this.appendFooter(sb);
            return sb.toString();
        }

        /** Appends a static block that assigns every condition's MethodHandle field. */
        private void appendStaticInitializer(StringBuilder sb) {
            sb.append("\n");
            indent(sb, 1).append("static {\n");
            indent(sb, 2).append("try {\n");
            for (ConditionSource condition : this.conditionSources) {
                indent(sb, 3);
                condition.appendDefinition(sb);
            }
            indent(sb, 2).append("} catch (Exception e) {\n");
            indent(sb, 3).append("throw new IllegalStateException(e);\n");
            indent(sb, 2).append("}\n");
            indent(sb, 1).append("}\n");
        }

        @Override
        public String binaryName() {
            return this.className;
        }

        @Override
        public String getSource() {
            return this.javaSource;
        }

        /** Appends the package clause, imports and class declaration line. */
        private void appendHeader(StringBuilder target) {
            target.append("package ").append(CLASS_PACKAGE).append(";\n\n");
            this.imports.asJavaImportStatements(target);
            String baseClassName = this.delegate.getRule().get(RULE_BASE_CLASS_PROPERTY, BaseRuleClass.class.getName());
            target.append("public final class ").append(this.classSimpleName).append(" extends ").append(baseClassName).append(" {\n");
        }

        private void appendFooter(StringBuilder target) {
            target.append("\n}\n");
        }
    }

    /** {@link LiteralEvaluator} backed by the static MethodHandle field of a compiled rule class. */
    private static final class LiteralEvaluatorImpl
    implements LiteralEvaluator {
        private final FieldReference[] descriptor;
        private final String source;
        private final MethodHandle handle;
        private final LiteralExpression sourceExpression;

        public LiteralEvaluatorImpl(Rule rule, ConditionSource compiled, Class<?> ruleClass) {
            this.descriptor = compiled.descriptor;
            this.source = compiled.source;
            this.handle = getHandle(ruleClass, compiled.handleName);
            this.sourceExpression = LiteralExpression.of(this.source, rule);
        }

        @Override
        public FieldReference[] descriptor() {
            return this.descriptor;
        }

        @Override
        public LiteralExpression getSource() {
            return this.sourceExpression;
        }

        /**
         * Evaluates the condition.
         *
         * @throws IllegalStateException wrapping any failure, with the condition source and argument values
         */
        @Override
        public boolean test(IntToValue values) {
            try {
                // invoke() is signature-polymorphic: the cast fixes the expected return type to boolean.
                return (boolean) this.handle.invoke(values);
            } catch (Throwable t) {
                throw new IllegalStateException("Evaluation exception at '" + this.source + "', arguments: " + Arrays.toString(this.descriptor) + " -> " + Arrays.toString(this.collectArguments(values)), t);
            }
        }

        /** Collects argument values for diagnostics without letting a failing lookup hide the original error. */
        private Object[] collectArguments(IntToValue values) {
            Object[] args = new Object[this.descriptor.length];
            for (int i = 0; i < args.length; ++i) {
                try {
                    args[i] = values.apply(i);
                } catch (RuntimeException e) {
                    args[i] = "<unavailable: " + e + ">";
                }
            }
            return args;
        }

        /** Reads a static MethodHandle field named {@code name} from {@code compiledClass}. */
        static MethodHandle getHandle(Class<?> compiledClass, String name) {
            try {
                return (MethodHandle) compiledClass.getDeclaredField(name).get(null);
            } catch (IllegalAccessException | NoSuchFieldException e) {
                throw new IllegalStateException("Handle not found", e);
            }
        }
    }

    /**
     * Generates the RHS part of a rule class: a static instance field plus a nested class that
     * fetches each declared fact by name and passes them to a private method containing the
     * literal RHS code.
     */
    private static final class RhsSource {
        final String rhs;
        final Rule rule;
        final StringJoiner methodArgs;
        final StringJoiner args;

        RhsSource(Rule rule, String rhs) {
            this.rule = rule;
            this.rhs = rhs;
            this.methodArgs = new StringJoiner(", ");
            this.args = new StringJoiner(", ");
            for (NamedType t : rule.getDeclaredFactTypes()) {
                this.methodArgs.add(t.getType().getJavaType() + " " + t.getName());
                this.args.add(t.getName());
            }
        }

        /** Appends {@code public static final Rhs ACTION = new Rhs();}. */
        void appendClassVar(StringBuilder target) {
            indent(target, 1)
                    .append("public static final ").append(RHS_CLASS_NAME).append(' ')
                    .append(RHS_INSTANCE_VAR).append(" = new ").append(RHS_CLASS_NAME).append("();\n");
        }

        /** Appends the nested RHS class wrapping the literal RHS code. */
        void appendClassBody(StringBuilder target) {
            target.append("\n");
            indent(target, 1).append("public static class ").append(RHS_CLASS_NAME).append(" extends ").append(AbstractLiteralRhs.class.getName()).append(" {\n\n");
            indent(target, 2).append("@Override\n");
            indent(target, 2).append("protected final void doRhs() {\n");
            for (NamedType t : this.rule.getDeclaredFactTypes()) {
                indent(target, 3).append(t.getType().getJavaType()).append(" ").append(t.getName()).append(" = get(\"").append(t.getName()).append("\");\n");
            }
            indent(target, 3).append("this.doRhs(").append(this.args).append(");\n");
            indent(target, 2).append("}\n\n");
            indent(target, 2).append("private void doRhs(").append(this.methodArgs).append(") {\n");
            String body = "/***** Start RHS source *****/\n" + this.rhs + "\n/****** End RHS source ******/";
            for (String line : body.split("\n")) {
                indent(target, 3).append(line).append("\n");
            }
            indent(target, 2).append("}\n");
            indent(target, 1).append("}");
        }
    }

    /**
     * Source generator for one literal condition. Field references in the condition are
     * replaced by typed method parameters; the generated public method unpacks an
     * {@link IntToValue} into those parameters and delegates to a private method that
     * returns the rewritten expression.
     */
    private static final class ConditionSource {
        private static final String DECLARATION_TEMPLATE = "public static final java.lang.invoke.MethodHandle %s;\n";
        private static final String DEFINITION_TEMPLATE = "%s = java.lang.invoke.MethodHandles.lookup().findStatic(%s.class, \"%s\", java.lang.invoke.MethodType.methodType(boolean.class, %s));\n";
        private static final String INNER_METHOD_TEMPLATE = "\n  private static boolean %sInner(%s) {\n    return %s;\n  }\n";
        private static final String HANDLE_METHOD_TEMPLATE = "\n  public static boolean %s(%s) {\n    return %sInner(%s);\n  }\n";
        final String source;
        final String methodName;
        final String handleName;
        final String className;
        private final String replaced;
        private final StringJoiner methodArgs;
        private final StringJoiner argCasts;
        private final FieldReference[] descriptor;

        public ConditionSource(Rule rule, String name, String className, String source, RuntimeContext<?> context, boolean stripWhitespaces) {
            this.className = className;
            this.source = source;
            this.methodName = name;
            // Locale.ROOT: generated identifiers must not depend on the JVM's default locale.
            this.handleName = name.toUpperCase(Locale.ROOT) + "_HANDLE";
            StringLiteralEncoder encoder = StringLiteralEncoder.of(source, stripWhitespaces);
            ExpressionResolver resolver = context.getExpressionResolver();
            List<ConditionStringTerm> terms = ConditionStringTerm.resolveTerms(encoder.getEncoded(), s -> resolver.resolve(s, rule));
            String encoded = encoder.getEncoded().value;
            StringBuilder expression = new StringBuilder(encoded.length());
            // Distinct references in first-seen order; position i maps to values.apply(i).
            List<ConditionStringTerm> uniqueTerms = new ArrayList<>();
            this.argCasts = new StringJoiner(", ");
            this.methodArgs = new StringJoiner(", ");
            int cursor = 0;
            for (ConditionStringTerm term : terms) {
                // Assumes terms are non-overlapping and ordered by position in the encoded expression.
                expression.append(encoded, cursor, term.start).append(term.varName);
                cursor = term.end;
                if (uniqueTerms.contains(term)) {
                    continue;
                }
                int argIndex = uniqueTerms.size();
                uniqueTerms.add(term);
                String typeName = term.field().getValueType().getCanonicalName();
                this.argCasts.add("(" + typeName + ") values.apply(" + argIndex + ")");
                this.methodArgs.add(typeName + " " + term.varName);
            }
            expression.append(encoded, cursor, encoded.length());
            this.replaced = encoder.unwrapLiterals(expression.toString());
            this.descriptor = uniqueTerms.toArray(FieldReference.ZERO_ARRAY);
        }

        void appendDeclaration(StringBuilder target) {
            target.append(String.format(DECLARATION_TEMPLATE, this.handleName));
        }

        void appendHandleMethod(StringBuilder target) {
            target.append(String.format(HANDLE_METHOD_TEMPLATE, this.methodName, IntToValue.class.getName() + " values", this.methodName, this.argCasts));
        }

        void appendInnerMethod(StringBuilder target) {
            target.append(String.format(INNER_METHOD_TEMPLATE, this.methodName, this.methodArgs, this.replaced));
        }

        void appendDefinition(StringBuilder target) {
            target.append(String.format(DEFINITION_TEMPLATE, this.handleName, this.className, this.methodName, IntToValue.class.getName() + ".class"));
        }
    }
}

