package org.kink_lang.kink.internal.compile.javaclassir;

import java.lang.invoke.CallSite;
import java.lang.invoke.MethodHandles.Lookup;
import java.lang.invoke.MethodType;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Stream;

import org.objectweb.asm.Handle;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;

import org.kink_lang.kink.Val;
import org.kink_lang.kink.Vm;
import org.kink_lang.kink.internal.compile.bootstrap.ConstBootstrapper;
import org.kink_lang.kink.internal.program.itree.*;

/**
 * Generates insns for controls for the case when
 * intrinsic implementations cannot be used.
 *
 * <p>Each control is lowered to an ordinary sym-call of the fun stored in
 * the corresponding local variable ({@code $if}, {@code $branch} or {@code $new_val}).
 * For new_val, the generated code first checks at runtime whether
 * {@code $new_val} is the preloaded new_val, and in that case
 * uses the intrinsic construction instead of the sym-call.</p>
 */
public class OverriddenControlGenerator implements ControlGenerator {

    /** The vm. */
    private final Vm vm;

    /** The program name. */
    private final String programName;

    /** The program text. */
    private final String programText;

    /** The key str supplier. */
    private final KeyStrSupplier keySup;

    /** The trace accumulator. */
    private final TraceAccumulator traceAccum;

    /**
     * Handle of ConstBootstrapper.bootstrapPreloaded,
     * which takes one String static argument such as {@code "new_val"}.
     */
    static final Handle BOOTSTRAP_PRELOADED_HANDLE = new Handle(Opcodes.H_INVOKESTATIC,
            Type.getType(ConstBootstrapper.class).getInternalName(),
            "bootstrapPreloaded",
            MethodType.methodType(
                CallSite.class,
                Lookup.class,
                String.class,
                MethodType.class,
                String.class).descriptorString(),
            false);

    /** Insn of producing the preloaded new_val. */
    static final Insn PRODUCE_NEW_VAL = new Insn.InvokeDynamic(
            MethodType.methodType(Val.class),
            BOOTSTRAP_PRELOADED_HANDLE,
            List.of("new_val"));

    /**
     * Constructs a generator.
     *
     * @param vm the vm.
     * @param programName the program name.
     * @param programText the program text.
     * @param keySup the key str supplier.
     * @param traceAccum the trace accumulator.
     */
    public OverriddenControlGenerator(Vm vm,
            String programName,
            String programText,
            KeyStrSupplier keySup,
            TraceAccumulator traceAccum) {
        this.vm = vm;
        this.programName = programName;
        this.programText = programText;
        this.keySup = keySup;
        this.traceAccum = traceAccum;
    }

    @Override
    public List<Insn> preloadedIf(
            IfItree itree,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx) {
        // if(cond $true_fun [$false_fun])
        List<ItreeElem> args = new ArrayList<>(3);
        args.add(itree.cond());
        args.add(itree.trueFun());
        itree.falseFun().ifPresent(args::add);
        return generate.apply(symcallVar("if", args, itree.pos()), resultCtx);
    }

    @Override
    public List<Insn> branch(
            BranchItree itree,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx) {
        // branch($cond1 $then1 $cond2 $then2 ...)
        List<ItreeElem> args = condThenPairsToArgs(itree.condThenPairs());
        return generate.apply(symcallVar("branch", args, itree.pos()), resultCtx);
    }

    @Override
    public List<Insn> branchWithElse(
            BranchWithElseItree itree,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx) {
        List<ItreeElem> args = new ArrayList<>(condThenPairsToArgs(itree.condThenPairs()));
        int pos = itree.pos();
        // the else clause is passed as the last pair: $true and the else-then fun
        args.add(lderef("true", pos));
        args.add(itree.elseThenFun());
        return generate.apply(symcallVar("branch", args, pos), resultCtx);
    }

    /**
     * Makes an lderef itree of the original local var.
     *
     * @param name the name of the local var.
     * @param pos the pos of the itree.
     * @return an lderef itree of {@code $name}.
     */
    private static LderefItree lderef(String name, int pos) {
        return new LderefItree(new LocalVar.Original(name), pos);
    }

    /**
     * Makes a sym-call itree {@code sym(args...)},
     * which calls the fun stored in the local var {@code sym}.
     *
     * @param sym the name of the local var, also used as the sym of the call.
     * @param args the args of the call.
     * @param pos the pos of the itree.
     * @return the sym-call itree.
     */
    private static SymcallItree symcallVar(String sym, List<ItreeElem> args, int pos) {
        return new SymcallItree(lderef(sym, pos), sym, new NadaItree(pos), args, pos);
    }

    /**
     * Flattens cond-then pairs into args: cond1, then1, cond2, then2, ...
     *
     * @param condThenPairs the cond-then pairs.
     * @return an unmodifiable list of the args.
     */
    private static List<ItreeElem> condThenPairsToArgs(List<CondThenPair> condThenPairs) {
        return condThenPairs.stream()
            .flatMap(pair -> Stream.of((ItreeElem) pair.condFun(), (ItreeElem) pair.thenFun()))
            .toList();
    }

    /**
     * Generates new_val with a trait.
     *
     * <p>Calls {@code $new_val} as a fallback unless it is the preloaded new_val;
     * otherwise the value is made by {@link NewVal#traitNewValCommonPart}.</p>
     */
    @Override
    public List<Insn> traitNewVal(
            TraitNewValItree itree,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx) {
        int pos = itree.pos();
        SymcallItree fallback = fallbackScallTraitNewVal(itree);
        String end = keySup.newKeyStr("end");

        List<Insn> insns = new ArrayList<>(evalFallback(fallback, generate, resultCtx, pos, end));
        insns.addAll(NewVal.traitNewValCommonPart(vm, itree, generate, resultCtx,
                    programName, programText,
                    keySup, traceAccum));
        insns.addAll(resultCtx.returnOnTailOrMark(end));
        return insns;
    }

    /**
     * Generates new_val without a trait.
     *
     * <p>Calls {@code $new_val} as a fallback unless it is the preloaded new_val;
     * otherwise the value is made by {@link NewVal#noTraitNewValCommonPart}.</p>
     */
    @Override
    public List<Insn> noTraitNewVal(
            NoTraitNewValItree itree,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx) {
        int pos = itree.pos();
        SymcallItree fallback = fallbackScallNoTraitNewVal(itree);
        String end = keySup.newKeyStr("end");

        List<Insn> insns = new ArrayList<>(evalFallback(fallback, generate, resultCtx, pos, end));
        insns.addAll(NewVal.noTraitNewValCommonPart(vm, itree, generate, keySup));
        insns.addAll(resultCtx.returnOnTailOrMark(end));
        return insns;
    }

    /**
     * Makes a sym-call itree equivalent with trait new_val for fallback:
     * {@code new_val(...trait 'sym1' val1 'sym2' val2 ...)}.
     *
     * @param itree the trait new_val itree.
     * @return the sym-call itree.
     */
    private static SymcallItree fallbackScallTraitNewVal(TraitNewValItree itree) {
        List<ItreeElem> args = new ArrayList<>();
        args.add(new ItreeElem.Spread(itree.trait(), itree.spreadPos()));
        addSymValArgs(args, itree.symValPairs(), itree.pos());
        return symcallVar("new_val", args, itree.pos());
    }

    /**
     * Makes a sym-call itree equivalent with no-trait new_val for fallback:
     * {@code new_val('sym1' val1 'sym2' val2 ...)}.
     *
     * @param itree the no-trait new_val itree.
     * @return the sym-call itree.
     */
    private static SymcallItree fallbackScallNoTraitNewVal(NoTraitNewValItree itree) {
        List<ItreeElem> args = new ArrayList<>();
        addSymValArgs(args, itree.symValPairs(), itree.pos());
        return symcallVar("new_val", args, itree.pos());
    }

    /**
     * Appends each sym-val pair to args as a str itree of the sym followed by the val.
     *
     * @param args the list to append to.
     * @param symValPairs the sym-val pairs.
     * @param pos the pos of the str itrees.
     */
    private static void addSymValArgs(
            List<ItreeElem> args, Iterable<? extends SymValPair> symValPairs, int pos) {
        for (SymValPair pair : symValPairs) {
            args.add(new StrItree(pair.sym(), pos));
            args.add(pair.val());
        }
    }

    /**
     * Generates insns which evaluate {@code $new_val} and dispatch on it.
     *
     * <p>If {@code $new_val} is the preloaded new_val, control jumps to the
     * {@code intrinsic-new_val} mark, which is the last insn of the result,
     * so that the caller can append the intrinsic construction after it.
     * Otherwise {@code fallback} is called, and then, unless {@code resultCtx}
     * is the tail context, control jumps to {@code end}.</p>
     *
     * @param fallback the sym-call itree used when $new_val is not the preloaded one.
     * @param generate the itree-to-insns generator.
     * @param resultCtx the result context of the whole new_val expression.
     * @param pos the pos of the new_val expression.
     * @param end the key str of the end mark, which must be placed by the caller.
     * @return the generated insns.
     */
    private List<Insn> evalFallback(
            SymcallItree fallback,
            BiFunction<Itree, ResultContext, List<Insn>> generate,
            ResultContext resultCtx,
            int pos,
            String end) {
        // contParam = $new_val
        List<Insn> insns = new ArrayList<>(generate.apply(lderef("new_val", pos), ResultContext.NON_TAIL));

        // if (contParam == new_val) goto #intrinsic-new_val
        insns.add(InsnsGenerator.LOAD_CONTPARAM);
        insns.add(PRODUCE_NEW_VAL);
        String intrinsicNewVal = keySup.newKeyStr("intrinsic-new_val");
        insns.add(new Insn.IfEq(Type.getType(Val.class), intrinsicNewVal));

        // call fallback; goto #end
        // (in the tail context the fallback call itself leaves the fun, so no jump is needed)
        insns.addAll(generate.apply(fallback, resultCtx));
        if (! resultCtx.equals(ResultContext.TAIL)) {
            insns.add(new Insn.GoTo(end));
        }

        // #intrinsic-new_val:
        insns.add(new Insn.Mark(intrinsicNewVal));

        return insns;
    }

}

// vim: et sw=4 sts=4 fdm=marker
