package org.kink_lang.kink;

import java.util.Locale;
import java.util.HashMap;
import java.util.Map;

import org.kink_lang.kink.hostfun.CallContext;
import org.kink_lang.kink.hostfun.HostResult;
import org.kink_lang.kink.internal.callstack.CallStackSlice;
import org.kink_lang.kink.internal.callstack.Trace;
import org.kink_lang.kink.internal.function.ThrowingFunction3;
import org.kink_lang.kink.internal.function.ThrowingFunction4;

/**
 * Support class of kont tags.
 *
 * <p>The helper builds the shared vars holding the methods
 * {@code reset}, {@code shift}, {@code can_shift?} and {@code repr},
 * and the trace made from the symbol {@code ..kont tag..}.
 * Both are only available after {@link #init()} is called.</p>
 */
class KontTagHelper {

    /** The vm. */
    private final Vm vm;

    /** Trace for kont tags; assigned by {@link #init()}. */
    Trace trace;

    /** Shared vars of kont tag vals; assigned by {@link #init()}. */
    SharedVars sharedVars;

    /**
     * Constructs the helper.
     *
     * @param vm the vm which the helper works for.
     */
    KontTagHelper(Vm vm) {
        this.vm = vm;
    }

    /**
     * Initializes the helper.
     *
     * <p>Sets up {@link #trace} and registers the kont tag methods
     * into {@link #sharedVars}.</p>
     */
    void init() {
        this.trace = Trace.of(vm.sym.handleFor("..kont tag.."));
        Map<Integer, Val> vars = new HashMap<>();
        addMethod1(vars, "Kont_tag", "reset", "$thunk", this::resetMethod);
        addMethod1(vars, "Kont_tag", "shift", "$abort", this::shiftMethod);
        addMethod0(vars, "Kont_tag", "can_shift?", this::canShiftMethod);
        addMethod0(vars, "Kont_tag", "repr", this::reprMethod);
        this.sharedVars = vm.sharedVars.of(vars);
    }

    /**
     * Adds a nullary method to {@code vars}.
     *
     * <p>The registered fun checks that the receiver is a kont tag
     * before delegating to {@code action}.</p>
     *
     * @param vars the mapping from sym handles to method funs, which is modified.
     * @param recvDesc the description of the receiver, used in desc and error messages.
     * @param name the sym of the method.
     * @param action the body, taking the call context, the desc of the method,
     *      and the receiver kont tag.
     */
    private void addMethod0(
            Map<Integer, Val> vars,
            String recvDesc,
            String name,
            ThrowingFunction3<CallContext, String, KontTagVal, HostResult> action) {
        int symHandle = vm.sym.handleFor(name);
        // A nullary method has no parameter list, so the desc is just "Recv.name".
        String desc = String.format(Locale.ROOT, "%s.%s", recvDesc, name);
        vars.put(symHandle, vm.fun.make(desc).take(0).action(c -> {
            Val recv = c.recv();
            if (! (recv instanceof KontTagVal kontTag)) {
                return raiseNotKontTag(c, desc, recvDesc, recv);
            }
            return action.apply(c, desc, kontTag);
        }));
    }

    /**
     * Adds a unary method to {@code vars}.
     *
     * <p>The registered fun checks that the receiver is a kont tag
     * before delegating to {@code action} with the sole argument.</p>
     *
     * @param vars the mapping from sym handles to method funs, which is modified.
     * @param recvDesc the description of the receiver, used in desc and error messages.
     * @param name the sym of the method.
     * @param arg0Desc the description of the argument, used in desc.
     * @param action the body, taking the call context, the desc of the method,
     *      the receiver kont tag, and the argument.
     */
    private void addMethod1(
            Map<Integer, Val> vars,
            String recvDesc,
            String name,
            String arg0Desc,
            ThrowingFunction4<CallContext, String, KontTagVal, Val, HostResult> action) {
        int symHandle = vm.sym.handleFor(name);
        String desc = String.format(Locale.ROOT, "%s.%s(%s)", recvDesc, name, arg0Desc);
        vars.put(symHandle, vm.fun.make(desc).take(1).action(c -> {
            Val recv = c.recv();
            if (! (recv instanceof KontTagVal kontTag)) {
                return raiseNotKontTag(c, desc, recvDesc, recv);
            }
            Val arg0 = c.arg(0);
            return action.apply(c, desc, kontTag, arg0);
        }));
    }

    /**
     * Raises an exception telling that the receiver is not a kont tag.
     *
     * @param c the call context of the method invocation.
     * @param desc the description of the invoked method.
     * @param recvDesc the description of the receiver.
     * @param recv the actual receiver.
     * @return the host result which raises the exception.
     */
    private HostResult raiseNotKontTag(
            CallContext c, String desc, String recvDesc, Val recv) {
        return c.call(vm.graph.raiseFormat("{}: {} must be kont_tag, but got {}",
                    vm.graph.of(vm.str.of(desc)),
                    vm.graph.of(vm.str.of(recvDesc)),
                    vm.graph.repr(recv)));
    }

    /**
     * Implementation of Kont_tag.reset.
     *
     * @param c the call context.
     * @param desc the description of the method, used in error messages.
     * @param kontTag the receiver.
     * @param arg the argument, which must be a fun.
     * @return the host result which calls the thunk under the kont tag,
     *      or raises an exception if {@code arg} is not a fun.
     */
    private HostResult resetMethod(CallContext c, String desc, KontTagVal kontTag, Val arg) {
        if (! (arg instanceof FunVal thunk)) {
            return c.call(vm.graph.raiseFormat("{}: $thunk must be fun, but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.repr(arg)));
        }
        return c.call(resetDelegateFun(kontTag, thunk));
    }

    /**
     * Delegate reset work to this.
     *
     * <p>The returned fun pushes {@code kontTag} onto the call stack,
     * then calls {@code thunk} with nada as the receiver and no args.</p>
     *
     * @param kontTag the kont tag to push onto the call stack.
     * @param thunk the fun to call.
     * @return the fun which manipulates the stack machine directly.
     */
    private FunVal resetDelegateFun(KontTagVal kontTag, FunVal thunk) {
        return new FunVal(vm) {
            @Override void run(StackMachine stackMachine) {
                var callStack = stackMachine.getCallStack();
                if (! callStack.isModerateSize()) {
                    stackMachine.transitionToRaise("call stack overflow");
                    return;
                }
                // NOTE(review): presumably the pushed kont tag is what
                // canAbort/abort look up in shift; the meaning of the three
                // zeros is not visible from here -- confirm against the call stack.
                callStack.pushCse(kontTag, 0, 0, 0);

                // Discard the current frame contents, and leave only the recv of the thunk.
                var dataStack = stackMachine.getDataStack();
                dataStack.removeFromOffset(0);
                dataStack.push(vm.nada);

                stackMachine.transitionToCall(thunk);
            }
        };
    }

    /**
     * Implementation of Kont_tag.shift.
     *
     * @param c the call context.
     * @param desc the description of the method, used in error messages.
     * @param kontTag the receiver.
     * @param arg the argument, which must be a fun.
     * @return the host result which performs shift,
     *      or raises an exception if {@code arg} is not a fun.
     */
    private HostResult shiftMethod(CallContext c, String desc, KontTagVal kontTag, Val arg) {
        if (! (arg instanceof FunVal abort)) {
            return c.call(vm.graph.raiseFormat("{}: $abort must be fun, but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.repr(arg)));
        }
        return c.call(shiftDelegateFun(desc, kontTag, abort));
    }

    /**
     * Delegate shift work to this.
     *
     * <p>If the call stack cannot abort to {@code kontTag},
     * the returned fun raises an exception.
     * Otherwise, it cuts off slices of the call stack and the data stack,
     * wraps them in a resume fun,
     * and calls {@code abort} with the resume fun as the sole arg.</p>
     *
     * @param desc the description of the method, used in the error message.
     * @param kontTag the kont tag to abort to.
     * @param abort the fun called with the resume fun.
     * @return the fun which manipulates the stack machine directly.
     */
    private FunVal shiftDelegateFun(String desc, KontTagVal kontTag, FunVal abort) {
        return new FunVal(vm) {
            @Override void run(StackMachine stackMachine) {
                var callStack = stackMachine.getCallStack();
                var dataStack = stackMachine.getDataStack();
                if (! callStack.canAbort(kontTag)) {
                    // Raise via an ordinary fun call: recv = nada, no args.
                    FunVal fun = vm.fun.make().action(c -> c.call(vm.graph.raiseFormat(
                                    "{}: {} not found on stack",
                                    vm.graph.of(vm.str.of(desc)),
                                    vm.graph.repr(kontTag))));
                    dataStack.removeFromOffset(0);
                    dataStack.push(vm.nada);
                    stackMachine.transitionToCall(fun);
                    return;
                }

                CallStackSlice callSlice = callStack.abort(kontTag);

                // Drop the contents of the current frame first,
                // then take as many data stack entries as the call slice reports,
                // move the bp back by the same amount, and clear the frame again.
                // The order of these steps matters.
                dataStack.removeFromOffset(0);
                int dataStackUsage = callSlice.dataStackUsage();
                Val[] dataSlice = stackMachine.getDataStack().sliceTop(dataStackUsage);
                dataStack.decreaseBp(dataStackUsage);
                dataStack.removeFromOffset(0);
                FunVal resumeFun = resumeFun(callSlice, dataSlice);

                // recv
                dataStack.push(vm.nada);
                // arg0 = resume fun
                dataStack.push(resumeFun);

                stackMachine.transitionToCall(abort);
            }
        };
    }

    /**
     * Kont resume fun.
     *
     * @param callSlice the slice of the call stack taken on shift.
     * @param dataSlice the slice of the data stack taken on shift.
     * @return the fun which takes one arg,
     *      restores the slices, and results in the arg.
     */
    private FunVal resumeFun(CallStackSlice callSlice, Val[] dataSlice) {
        return vm.fun.make().take(1).action(c -> {
            Val arg = c.arg(0);
            return c.call(resumeDelegate(arg, callSlice, dataSlice));
        });
    }

    /**
     * Delegate resume to this.
     *
     * <p>The returned fun pushes {@code dataSlice} back onto the data stack,
     * replays {@code callSlice} on the call stack,
     * and transitions to the result {@code arg}.
     * It raises an exception if either stack cannot accept the slice.</p>
     *
     * @param arg the val to be the result after the slices are restored.
     * @param callSlice the slice of the call stack to replay.
     * @param dataSlice the slice of the data stack to push back.
     * @return the fun which manipulates the stack machine directly.
     */
    private FunVal resumeDelegate(Val arg, CallStackSlice callSlice, Val[] dataSlice) {
        return new FunVal(vm) {
            @Override
            void run(StackMachine stackMachine) {
                var callStack = stackMachine.getCallStack();
                var dataStack = stackMachine.getDataStack();
                dataStack.removeFromOffset(0);

                // Check both stacks before modifying either of them with the slices.
                if (! callStack.canReplay(callSlice)) {
                    stackMachine.transitionToRaise("call stack overflow");
                    return;
                }

                if (! dataStack.ensureCapaSpPlus(dataSlice.length)) {
                    stackMachine.transitionToRaise("data stack overflow");
                    return;
                }

                // Mirror of shift: push the data back, advance the bp by the same amount,
                // then replay the call stack slice.
                dataStack.pushAll(dataSlice);
                dataStack.increaseBp(dataSlice.length);
                callStack.replay(callSlice);
                stackMachine.transitionToResult(arg);
            }
        };
    }

    /**
     * Implementation of Kont_tag.can_shift?.
     *
     * @param c the call context.
     * @param desc the description of the method; not used,
     *      but required by the signature of the action.
     * @param kontTag the receiver.
     * @return the host result which results in a bool val.
     */
    private HostResult canShiftMethod(CallContext c, String desc, KontTagVal kontTag) {
        return c.call(canShiftDelegate(kontTag));
    }

    /**
     * Delegate can_shift? to this.
     *
     * @param kontTag the kont tag to test.
     * @return the fun which results in the bool val
     *      of whether the call stack can abort to {@code kontTag}.
     */
    private FunVal canShiftDelegate(KontTagVal kontTag) {
        return new FunVal(vm) {
            @Override
            void run(StackMachine stackMachine) {
                var callStack = stackMachine.getCallStack();
                stackMachine.transitionToResult(vm.bool.of(callStack.canAbort(kontTag)));
            }
        };
    }

    /**
     * Implementation of Kont_tag.repr.
     *
     * @param c the call context; not used,
     *      but required by the signature of the action.
     * @param desc the description of the method; not used,
     *      but required by the signature of the action.
     * @param kontTag the receiver.
     * @return a str val in the form {@code (kont_tag val_id=ID)},
     *      where ID is the identity of the kont tag.
     */
    private HostResult reprMethod(CallContext c, String desc, KontTagVal kontTag) {
        return vm.str.of(String.format(Locale.ROOT, "(kont_tag val_id=%s)",
                    kontTag.identity()));
    }

}

// vim: et sw=4 sts=4 fdm=marker
