package org.kink_lang.kink.internal.compile.javaclassir;

import java.lang.invoke.CallSite;
import java.lang.invoke.MethodType;
import java.lang.invoke.MethodHandles.Lookup;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

import org.objectweb.asm.Handle;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.commons.Method;

import org.kink_lang.kink.FunVal;
import org.kink_lang.kink.Val;
import org.kink_lang.kink.Vm;
import org.kink_lang.kink.internal.callstack.CallStack;
import org.kink_lang.kink.internal.callstack.FakeCallTraceCse;
import org.kink_lang.kink.internal.callstack.Location;
import org.kink_lang.kink.internal.callstack.Trace;
import org.kink_lang.kink.internal.compile.bootstrap.ConstBootstrapper;
import org.kink_lang.kink.internal.intrinsicsupport.BranchSupport;
import org.kink_lang.kink.internal.intrinsicsupport.IfSupport;
import org.kink_lang.kink.internal.program.itree.*;

/**
 * Generator of intrinsic implementation of control.
 */
public class UnchangedControlGenerator implements ControlGenerator {

    /** The vm. */
    private final Vm vm;

    /** The program name. */
    private final String programName;

    /** The program text. */
    private final String programText;

    /** The key str supplier, used to make labels of jump targets. */
    private final KeyStrSupplier keySup;

    /** The trace accumulator, which registers traces referenced by bootstrap args. */
    private final TraceAccumulator traceAccum;

    /** Insn of invoking CallStack.pushTailTrace(Trace). */
    static final Insn INVOKE_PUSH_TAIL_TRACE = new Insn.InvokeVirtual(
            Type.getType(CallStack.class),
            new Method("pushTailTrace",
                Type.VOID_TYPE,
                new Type[] { Type.getType(Trace.class) }));

    /** Insn of producing true. */
    static final Insn PRODUCE_TRUE = produceConstVal("bootstrapTrue");

    /** Insn of producing false. */
    static final Insn PRODUCE_FALSE = produceConstVal("bootstrapFalse");

    /** Insn of invoking CallStack.popFakeCallTrace(). */
    private static final Insn INVOKE_POP_FAKE_CALL_TRACE = new Insn.InvokeVirtual(
            Type.getType(CallStack.class),
            new Method("popFakeCallTrace", Type.VOID_TYPE, new Type[0]));

    /** Insns of dataStack.removeFromOffset(0). */
    private static final List<Insn> REMOVE_FROM_0 = List.of(
            InsnsGenerator.LOAD_DATASTACK,
            new Insn.PushInt(0),
            InsnsGenerator.INVOKE_REMOVE_FROM_OFFSET);

    /** Insns of pushing nada to the data stack. */
    private static final List<Insn> PUSH_NADA = List.of(
            InsnsGenerator.LOAD_DATASTACK,
            InsnsGenerator.PRODUCE_NADA,
            InsnsGenerator.INVOKE_PUSH_TO_DATASTACK);

    /** Insns of callStack.popFakeCallTrace(). */
    private static final List<Insn> POP_FAKE_CALL_TRACE = List.of(
            InsnsGenerator.LOAD_CALLSTACK,
            INVOKE_POP_FAKE_CALL_TRACE);

    /** Insns of loading the "vm" field of this (typed as JavaClassIr.TYPE_BASE). */
    private static final List<Insn> LOAD_VM = List.of(
            new Insn.LoadThis(),
            new Insn.GetField(JavaClassIr.TYPE_BASE, "vm", Type.getType(Vm.class)));

    /**
     * Constructs a generator.
     *
     * @param vm the vm.
     * @param programName the program name.
     * @param programText the program text.
     * @param keySup the key str supplier.
     * @param traceAccum the trace accumulator.
     */
    public UnchangedControlGenerator(Vm vm,
            String programName,
            String programText,
            KeyStrSupplier keySup,
            TraceAccumulator traceAccum) {
        this.vm = vm;
        this.programName = programName;
        this.programText = programText;
        this.keySup = keySup;
        this.traceAccum = traceAccum;
    }

    /**
     * Makes an invokedynamic insn which produces a constant Val
     * through the given bootstrap method of ConstBootstrapper.
     *
     * @param bootstrapMethodName the name of the static bootstrap method of ConstBootstrapper.
     * @return the insn.
     */
    private static Insn produceConstVal(String bootstrapMethodName) {
        return new Insn.InvokeDynamic(
                MethodType.methodType(Val.class),
                new Handle(Opcodes.H_INVOKESTATIC,
                    Type.getType(ConstBootstrapper.class).getInternalName(),
                    bootstrapMethodName,
                    MethodType.methodType(
                        CallSite.class,
                        Lookup.class,
                        String.class,
                        MethodType.class).descriptorString(),
                    false),
                List.of());
    }

    /**
     * Makes a trace of a call of the named fun at the position in the program.
     *
     * @param funName the name of the fun.
     * @param pos the position in the program text.
     * @return the trace.
     */
    private Trace traceAt(String funName, int pos) {
        return Trace.of(
                vm.sym.handleFor(funName),
                new Location(programName, programText, pos));
    }

    @Override
    public List<Insn> preloadedIf(
            IfItree itree,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx) {
        // contParam = eval cond
        List<Insn> insns = new ArrayList<>(generate.apply(itree.cond(), ResultContext.NON_TAIL));

        int pos = itree.pos();
        Trace trace = traceAt("if", pos);
        boolean isTail = resultCtx.equals(ResultContext.TAIL);

        // NOTE(review): every non-TAIL context (including SYNTHETIC_TAIL) pushes a fake cse
        // here and pops it at #endif, whereas branch pushes a tail trace for SYNTHETIC_TAIL
        // and does not pop. Confirm this asymmetry is intended.
        insns.addAll(isTail
                ? pushTailTrace(trace)
                : pushFakeTraceCse(resultCtx.onTailOrNot(trace)));

        // if contParam == true; goto #cond-is-true
        String condIsTrue = keySup.newKeyStr("cond-is-true");
        insns.addAll(jumpIfContParamIs(PRODUCE_TRUE, condIsTrue));

        // if contParam == false; goto #cond-is-false
        String condIsFalse = keySup.newKeyStr("cond-is-false");
        insns.addAll(jumpIfContParamIs(PRODUCE_FALSE, condIsFalse));

        // cond is neither true nor false: tail call cond-not-bool; return
        insns.addAll(tailCallIfCondNotBool());
        insns.add(new Insn.ReturnValue());

        ResultContext contCtx = resultCtx.actAsTail();

        // #cond-is-true:
        insns.add(new Insn.Mark(condIsTrue));
        insns.addAll(generate.apply(itree.trueFun().body(), contCtx));
        // the key is made after generating the true body to keep the original key numbering
        String endif = keySup.newKeyStr("endif");
        if (! isTail) {
            insns.add(new Insn.GoTo(endif));
        }

        // #cond-is-false:
        insns.add(new Insn.Mark(condIsFalse));
        Itree falseCont = itree.falseFun().map(FastFunItree::body)
            .orElse(new NadaItree(pos));
        insns.addAll(generate.apply(falseCont, contCtx));

        if (! isTail) {
            // #endif:
            insns.add(new Insn.Mark(endif));
            // callStack.popFakeCallTrace()
            insns.addAll(POP_FAKE_CALL_TRACE);
        }
        return insns;
    }

    @Override
    public List<Insn> branch(
            BranchItree itree,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx) {
        List<Insn> insns = new ArrayList<>(pushBranchTrace(itree.pos(), resultCtx));

        String endBranch = keySup.newKeyStr("end-branch");
        insns.addAll(evalCondThenPairs(itree.condThenPairs(), generate, resultCtx, endBranch));

        // no cond matched: tail call no-matching-cond; return
        insns.addAll(tailCallNoMatchingCond());
        insns.add(new Insn.ReturnValue());

        insns.addAll(endBranch(resultCtx, endBranch));
        return insns;
    }

    @Override
    public List<Insn> branchWithElse(
            BranchWithElseItree itree,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx) {
        List<Insn> insns = new ArrayList<>(pushBranchTrace(itree.pos(), resultCtx));

        String endBranch = keySup.newKeyStr("end-branch");
        insns.addAll(evalCondThenPairs(itree.condThenPairs(), generate, resultCtx, endBranch));

        // no cond matched: eval the else-then body
        insns.addAll(generate.apply(itree.elseThenFun().body(), resultCtx.actAsTail()));

        insns.addAll(endBranch(resultCtx, endBranch));
        return insns;
    }

    /**
     * Insns to push the trace of a branch call at the start of the branch.
     *
     * <p>A fake trace cse is pushed in the NON_TAIL context;
     * otherwise a tail trace is pushed.</p>
     *
     * @param pos the position of the branch call.
     * @param resultCtx the result context of the branch call.
     * @return the insns.
     */
    private List<Insn> pushBranchTrace(int pos, ResultContext resultCtx) {
        Trace trace = traceAt("branch", pos);
        return resultCtx.equals(ResultContext.NON_TAIL)
            ? pushFakeTraceCse(trace)
            : pushTailTrace(trace);
    }

    /**
     * Generates evaluation of cond-then pairs.
     *
     * <p>For each pair, the cond fun body is evaluated under a fake call cse.
     * If the result is true, the then fun body is evaluated,
     * followed by a jump to {@code endBranch} unless the context is TAIL.
     * If the result is false, control falls through to the next pair.
     * Otherwise cond-not-bool is tail-called.</p>
     *
     * @param condThenPairs the cond-then pairs.
     * @param generate the generator of insns of subtrees.
     * @param resultCtx the result context of the branch call.
     * @param endBranch the label of the end of the branch.
     * @return the insns.
     */
    private List<Insn> evalCondThenPairs(
            List<CondThenPair> condThenPairs,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx,
            String endBranch) {
        List<Insn> insns = new ArrayList<>();

        int callByHost = traceAccum.add(Trace.of(vm.sym.handleFor("..call by host..")));
        for (int pairInd = 0; pairInd < condThenPairs.size(); ++ pairInd) {
            CondThenPair condThenPair = condThenPairs.get(pairInd);
            int pairNum = pairInd + 1;

            // callStack.pushCse(trace of call, 0, 0, 0)
            insns.addAll(InsnsGenerator.pushFakeCall(callByHost));

            // contParam = cond
            insns.addAll(generate.apply(condThenPair.condFun().body(),
                        ResultContext.SYNTHETIC_TAIL));

            // callStack.popFakeCallTrace()
            insns.addAll(POP_FAKE_CALL_TRACE);

            // if contParam == true; goto #condN-is-true
            String condIsTrue = keySup.newKeyStr("cond" + pairNum + "-is-true");
            insns.addAll(jumpIfContParamIs(PRODUCE_TRUE, condIsTrue));

            // if contParam == false; goto #condN-is-false
            String condIsFalse = keySup.newKeyStr("cond" + pairNum + "-is-false");
            insns.addAll(jumpIfContParamIs(PRODUCE_FALSE, condIsFalse));

            // tail call cond-not-bool; return
            insns.addAll(tailCallCondNotBool(pairInd));
            insns.add(new Insn.ReturnValue());

            // #condN-is-true:
            insns.add(new Insn.Mark(condIsTrue));

            // eval $then body
            insns.addAll(generate.apply(condThenPair.thenFun().body(), resultCtx.actAsTail()));
            if (! resultCtx.equals(ResultContext.TAIL)) {
                insns.add(new Insn.GoTo(endBranch));
            }

            // #condN-is-false:
            insns.add(new Insn.Mark(condIsFalse));
        }

        return insns;
    }

    /**
     * Insns to jump to the label if contParam is identical to the produced val.
     *
     * @param produceVal the insn producing the val to compare with.
     * @param label the label of the jump target.
     * @return the insns.
     */
    private static List<Insn> jumpIfContParamIs(Insn produceVal, String label) {
        return List.of(
                InsnsGenerator.LOAD_CONTPARAM,
                produceVal,
                new Insn.IfEq(Type.getType(Val.class), label));
    }

    /**
     * Insns to push a fake trace cse.
     *
     * @param trace the trace of the fake cse; it is registered to the trace accumulator as is.
     * @return the insns.
     */
    private List<Insn> pushFakeTraceCse(Trace trace) {
        // callStack.pushCse(fakeCallCse, 0, 0, 0)
        return List.of(InsnsGenerator.LOAD_CALLSTACK,
                new Insn.InvokeDynamic(
                    MethodType.methodType(FakeCallTraceCse.class),
                    InsnsGenerator.BOOTSTRAP_FAKE_CALL_CSE,
                    List.of(traceAccum.add(trace))),
                new Insn.PushInt(0),
                new Insn.PushInt(0),
                new Insn.PushInt(0),
                InsnsGenerator.INVOKE_PUSH_CSE);
    }

    /**
     * Insns to push a tail trace.
     *
     * @param trace the trace; its onTail() variant is registered to the trace accumulator.
     * @return the insns.
     */
    private List<Insn> pushTailTrace(Trace trace) {
        // callStack.pushTailTrace(trace)
        return List.of(InsnsGenerator.LOAD_CALLSTACK,
                new Insn.InvokeDynamic(
                    MethodType.methodType(Trace.class),
                    InsnsGenerator.BOOTSTRAP_TRACE_HANDLE,
                    List.of(traceAccum.add(trace.onTail()))),
                INVOKE_PUSH_TAIL_TRACE);
    }

    /**
     * Insns to tail-call the cond-not-bool fun of if.
     *
     * <p>The caller is responsible for emitting the return insn.</p>
     *
     * @return the insns.
     */
    private static List<Insn> tailCallIfCondNotBool() {
        List<Insn> insns = new ArrayList<>();
        insns.addAll(REMOVE_FROM_0);
        insns.addAll(PUSH_NADA);

        // stackMachine.transitionToCall(IfSupport.condNotBool(vm, contParam))
        insns.add(InsnsGenerator.LOAD_STACKMACHINE);
        insns.addAll(LOAD_VM);
        insns.add(InsnsGenerator.LOAD_CONTPARAM);
        insns.add(new Insn.InvokeStatic(Type.getType(IfSupport.class),
                    new Method("condNotBool", Type.getType(FunVal.class),
                        new Type[] { Type.getType(Vm.class), Type.getType(Val.class) })));
        insns.add(InsnsGenerator.INVOKE_TRANSITION_TO_CALL);
        return insns;
    }

    /**
     * Insns to tail-call the cond-not-bool fun of branch.
     *
     * <p>The caller is responsible for emitting the return insn.</p>
     *
     * @param pairInd the zero-based index of the cond-then pair whose cond is not bool.
     * @return the insns.
     */
    private static List<Insn> tailCallCondNotBool(int pairInd) {
        // each pair occupies two args (cond fun and then fun), so the cond is at pairInd * 2
        int argInd = pairInd * 2;
        List<Insn> insns = new ArrayList<>();

        insns.addAll(REMOVE_FROM_0);
        insns.addAll(PUSH_NADA);

        // stackMachine.transitionToCall(BranchSupport.condNotBool(vm, argInd, contParam))
        insns.add(InsnsGenerator.LOAD_STACKMACHINE);
        insns.addAll(LOAD_VM);
        insns.add(new Insn.PushInt(argInd));
        insns.add(InsnsGenerator.LOAD_CONTPARAM);
        insns.add(new Insn.InvokeStatic(Type.getType(BranchSupport.class),
                    new Method("condNotBool", Type.getType(FunVal.class),
                        new Type[] {
                            Type.getType(Vm.class),
                            Type.INT_TYPE,
                            Type.getType(Val.class) })));
        insns.add(InsnsGenerator.INVOKE_TRANSITION_TO_CALL);

        return insns;
    }

    /**
     * Insns to tail-call the no-matching-cond fun of branch.
     *
     * <p>The caller is responsible for emitting the return insn.</p>
     *
     * @return the insns.
     */
    private static List<Insn> tailCallNoMatchingCond() {
        List<Insn> insns = new ArrayList<>();
        insns.addAll(REMOVE_FROM_0);
        insns.addAll(PUSH_NADA);

        // stackMachine.transitionToCall(BranchSupport.noMatchingCond(vm))
        insns.add(InsnsGenerator.LOAD_STACKMACHINE);
        insns.addAll(LOAD_VM);
        insns.add(new Insn.InvokeStatic(Type.getType(BranchSupport.class),
                    new Method("noMatchingCond", Type.getType(FunVal.class),
                        new Type[] { Type.getType(Vm.class) })));
        insns.add(InsnsGenerator.INVOKE_TRANSITION_TO_CALL);
        return insns;
    }

    /**
     * Insns to end the branch.
     *
     * <p>Nothing is emitted in the TAIL context, because every path has already returned.
     * Otherwise the end label is marked, and the fake trace cse pushed for NON_TAIL is popped.</p>
     *
     * @param resultCtx the result context of the branch call.
     * @param endBranch the label of the end of the branch.
     * @return the insns.
     */
    private static List<Insn> endBranch(ResultContext resultCtx, String endBranch) {
        List<Insn> insns = new ArrayList<>();
        if (! resultCtx.equals(ResultContext.TAIL)) {
            insns.add(new Insn.Mark(endBranch));
            if (resultCtx.equals(ResultContext.NON_TAIL)) {
                insns.addAll(POP_FAKE_CALL_TRACE);
            }
        }
        return insns;
    }

    @Override
    public List<Insn> traitNewVal(
            TraitNewValItree itree,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx) {
        List<Insn> insns = new ArrayList<>(
                NewVal.traitNewValCommonPart(vm, itree, generate, resultCtx,
                    programName, programText,
                    keySup, traceAccum));
        insns.addAll(resultCtx.returnOnTail());
        return insns;
    }

    @Override
    public List<Insn> noTraitNewVal(
            NoTraitNewValItree itree,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx) {
        List<Insn> insns = new ArrayList<>(
                NewVal.noTraitNewValCommonPart(vm, itree, generate, keySup));
        insns.addAll(resultCtx.returnOnTail());
        return insns;
    }

}

// vim: et sw=4 sts=4 fdm=marker
