package org.kink_lang.kink.internal.compile.javaclassir;

import java.lang.invoke.MethodType;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import org.objectweb.asm.Type;
import org.objectweb.asm.commons.Method;

import org.kink_lang.kink.Val;
import org.kink_lang.kink.Vm;
import org.kink_lang.kink.internal.ovis.OwnVarIndexes;
import org.kink_lang.kink.internal.program.itree.*;

/**
 * Compiles binding-capture fast funs.
 *
 * <p>The produced class holds two copies of the fun body. At run time, the
 * generated code asks {@code this.valField0.getOvis().containsPreloadedVar()}.
 * If the answer is false, the copy compiled with unchanged control generators
 * runs. Otherwise, execution jumps to the copy compiled with overridden
 * control generators.</p>
 */
public class BindingCaptureFastFunCompiler {

    /** Logger for the class. */
    private static final System.Logger LOGGER
        = System.getLogger(BindingCaptureFastFunCompiler.class.getName());

    /** ASM type of {@link Val}. */
    private static final Type VAL_TYPE = Type.getType(Val.class);

    /** ASM type of {@link OwnVarIndexes}. */
    private static final Type OVIS_TYPE = Type.getType(OwnVarIndexes.class);

    /** Name of the generated field from which free vars are read. */
    private static final String VAL_FIELD_0 = "valField0";

    /** The vm. */
    private final Vm vm;

    /** The name of the program. */
    private final String programName;

    /** The text of the program. */
    private final String programText;

    /**
     * Constructs a compiler.
     *
     * @param vm the vm.
     * @param programName the name of the program.
     * @param programText the text of the program.
     */
    public BindingCaptureFastFunCompiler(Vm vm, String programName, String programText) {
        this.vm = vm;
        this.programName = programName;
        this.programText = programText;
    }

    /**
     * Compiles the fun.
     *
     * <p>The instruction list is laid out as the prologue, then the
     * unchanged-control section, then the overridden-control section,
     * then the epilogue.</p>
     *
     * @param fun the fun to compile.
     * @return jcir.
     */
    public JavaClassIr compile(FastFunItree fun) {
        var keySup = new KeyStrSupplier();
        var traceAccum = new TraceAccumulator();
        var jcirAccum = new ChildJcirAccumulator();
        var pcSup = new ProgramCounterSupplier(traceAccum);

        // Both sections share the same suppliers and accumulators, so keys,
        // traces and child jcirs stay unique across the whole class.
        // The unchanged section must be generated first.
        List<Insn> unchangedSection = controlUnchanged(fun, keySup, traceAccum, jcirAccum, pcSup);
        List<Insn> overriddenSection = controlOverridden(fun, keySup, traceAccum, jcirAccum, pcSup);

        List<Insn> allInsns = new ArrayList<>();
        allInsns.addAll(CompilerSupport.PROLOGUE);
        allInsns.addAll(unchangedSection);
        allInsns.addAll(overriddenSection);
        allInsns.addAll(CompilerSupport.EPILOGUE);

        var location = vm.location.of(programName, programText, fun.pos());
        String desc = String.format(Locale.ROOT, "(binding-capture-fast-fun location=%s)",
                location.desc());
        LOGGER.log(System.Logger.Level.TRACE, "insns for {0}: {1}", desc, allInsns);

        // NOTE(review): the leading 1 presumably is the count of val fields
        // (only valField0 is read by this class) — confirm against JavaClassIr.
        return new JavaClassIr(
                1,
                allInsns,
                traceAccum.traces(),
                desc,
                jcirAccum.childJcirFactories());
    }

    /**
     * Generates the section used when the controls are unchanged.
     *
     * <p>The section starts with a guard. The guard jumps past the whole
     * section, to a mark labeled with a fresh "control-overridden" key, when
     * {@code this.valField0.getOvis().containsPreloadedVar()} is true.</p>
     *
     * @param fun the fun to compile.
     * @param keySup the supplier of label keys.
     * @param traceAccum the accumulator of traces.
     * @param jcirAccum the accumulator of child jcirs.
     * @param pcSup the supplier of program counters.
     * @return insns of the section, ending with the control-overridden mark.
     */
    private List<Insn> controlUnchanged(
            FastFunItree fun,
            KeyStrSupplier keySup,
            TraceAccumulator traceAccum,
            ChildJcirAccumulator jcirAccum,
            ProgramCounterSupplier pcSup) {
        // guard: if this.valField0.getOvis().containsPreloadedVar() then goto #control-overridden
        List<Insn> section = new ArrayList<>(loadValField0());
        section.add(new Insn.InvokeVirtual(VAL_TYPE,
                    new Method("getOvis", OVIS_TYPE, new Type[0])));
        section.add(new Insn.InvokeVirtual(OVIS_TYPE,
                    new Method("containsPreloadedVar", Type.BOOLEAN_TYPE, new Type[0])));
        String overriddenLabel = keySup.newKeyStr("control-overridden");
        section.add(new Insn.IfNonZero(overriddenLabel));

        var allocationSet = AllocationSet.bindingCaptureControlUnchanged(fun);
        var lvarAccGen = new FastLvarAccessGenerator(allocationSet, keySup, traceAccum);
        var controlGen = new UnchangedControlGenerator(
                vm, programName, programText, keySup, traceAccum);
        var childCompiler = new ValCaptureFastFunCompiler(vm, programName, programText);
        var makeFunGen = new MakeValCaptureFastFunGenerator(
                lvarAccGen,
                childCompiler::compileControlUnchanged,
                AllocationSet::valCaptureControlUnchanged,
                jcirAccum);
        var letRecGen = new InFastFunLetRecGenerator(lvarAccGen, makeFunGen);
        var insnsGen = new InsnsGenerator(
                vm, programName, programText,
                BindingGenerator.NOT_AVAILABLE,
                lvarAccGen,
                makeFunGen,
                letRecGen,
                controlGen,
                keySup,
                traceAccum,
                pcSup,
                jcirAccum);

        section.addAll(frameAndBody(fun, allocationSet.stack(), lvarAccGen, insnsGen));

        // the guard above jumps here
        section.add(new Insn.Mark(overriddenLabel));
        return section;
    }

    /**
     * Generates the section used when the controls are overridden.
     *
     * @param fun the fun to compile.
     * @param keySup the supplier of label keys.
     * @param traceAccum the accumulator of traces.
     * @param jcirAccum the accumulator of child jcirs.
     * @param pcSup the supplier of program counters.
     * @return insns of the section.
     */
    private List<Insn> controlOverridden(
            FastFunItree fun,
            KeyStrSupplier keySup,
            TraceAccumulator traceAccum,
            ChildJcirAccumulator jcirAccum,
            ProgramCounterSupplier pcSup) {
        var allocationSet = AllocationSet.bindingCaptureControlOverridden(fun);
        var lvarAccGen = new FastLvarAccessGenerator(allocationSet, keySup, traceAccum);
        var controlGen = new OverriddenControlGenerator(
                vm, programName, programText, keySup, traceAccum);
        var childCompiler = new ValCaptureFastFunCompiler(vm, programName, programText);
        var makeFunGen = new MakeValCaptureFastFunGenerator(
                lvarAccGen,
                childCompiler::compileControlOverridden,
                AllocationSet::valCaptureControlOverridden,
                jcirAccum);
        var letRecGen = new InFastFunLetRecGenerator(lvarAccGen, makeFunGen);
        var insnsGen = new InsnsGenerator(
                vm, programName, programText,
                BindingGenerator.NOT_AVAILABLE,
                lvarAccGen,
                makeFunGen,
                letRecGen,
                controlGen,
                keySup,
                traceAccum,
                pcSup,
                jcirAccum);

        return frameAndBody(fun, allocationSet.stack(), lvarAccGen, insnsGen);
    }

    /**
     * Generates the part shared by both sections.
     *
     * <p>It allocates the lvar stack, copies the free vars into their
     * lvars, and then emits the fun body in tail context.</p>
     *
     * @param fun the fun to compile.
     * @param stack the lvars on the data stack of the allocation set.
     * @param lvarAccGen the generator used to store the free vars.
     * @param insnsGen the generator of the fun body.
     * @return the insns.
     */
    private List<Insn> frameAndBody(
            FastFunItree fun,
            List<LocalVar> stack,
            LvarAccessGenerator lvarAccGen,
            InsnsGenerator insnsGen) {
        List<Insn> part = new ArrayList<>();

        // allocate lvars
        part.addAll(CompilerSupport.allocateStack(stack.size()));

        // copy free vars into lvars
        part.addAll(loadFreeVars(lvarAccGen, extractFreeLvars(stack)));

        // fun body
        part.addAll(insnsGen.generate(fun.body(), ResultContext.TAIL));
        return part;
    }

    /**
     * Selects the free vars from the data stack.
     *
     * <p>A free var is an element that is an instance of
     * {@link LocalVar.Original}.</p>
     *
     * @param stack the lvars on the data stack.
     * @return an unmodifiable list of the free vars, in stack order.
     */
    private List<LocalVar> extractFreeLvars(List<LocalVar> stack) {
        return stack.stream()
            .filter(LocalVar.Original.class::isInstance)
            .toList();
    }

    /**
     * Generates insns that read each free var from {@code this.valField0}
     * and store it into the corresponding lvar.
     *
     * @param lvarAccGen the generator used to store the lvars.
     * @param freeLvars the free vars to load.
     * @return the insns.
     */
    private List<Insn> loadFreeVars(LvarAccessGenerator lvarAccGen, List<LocalVar> freeLvars) {
        List<Insn> loads = new ArrayList<>();
        for (LocalVar freeVar : freeLvars) {
            // contParam = get-var[handle](this.valField0)
            loads.addAll(loadValField0());
            loads.add(new Insn.InvokeDynamic(
                        MethodType.methodType(Val.class, Val.class),
                        InsnsGenerator.BOOTSTRAP_GET_VAR_HANDLE,
                        List.of(vm.sym.handleFor(freeVar.name()))));
            loads.add(InsnsGenerator.STORE_CONTPARAM);

            // lvar = contParam
            loads.addAll(lvarAccGen.storeLvar(freeVar));
        }
        return loads;
    }

    /**
     * Generates the insns which load {@code this.valField0}.
     *
     * @return an immutable list of the two insns.
     */
    private static List<Insn> loadValField0() {
        return List.of(
                new Insn.LoadThis(),
                new Insn.GetField(JavaClassIr.TYPE_BASE, VAL_FIELD_0, VAL_TYPE));
    }

}

// vim: et sw=4 sts=4 fdm=marker
