package org.kink_lang.kink;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

import javax.annotation.Nullable;

import org.kink_lang.kink.hostfun.CallContext;
import org.kink_lang.kink.hostfun.CallFlowToArgs;
import org.kink_lang.kink.hostfun.CallFlowToRecv;
import org.kink_lang.kink.hostfun.CallFlowToOn;
import org.kink_lang.kink.hostfun.HostFunReaction;
import org.kink_lang.kink.hostfun.HostResult;
import org.kink_lang.kink.hostfun.graph.GraphNode;
import org.kink_lang.kink.internal.contract.Preconds;
import org.kink_lang.kink.internal.callstack.ResumeCse;
import org.kink_lang.kink.internal.callstack.CallStack;
import org.kink_lang.kink.internal.callstack.CallStackSlice;
import org.kink_lang.kink.internal.callstack.Lnums;
import org.kink_lang.kink.internal.callstack.HostResumeCse;
import org.kink_lang.kink.internal.callstack.Trace;
import org.kink_lang.kink.internal.intrinsicsupport.ArgsSupport;

/**
 * The stack machine evaluator.
 */
class StackMachine {

    /** The vm. */
    private final Vm vm;

    /** The call stack. */
    private final CallStack callStack;

    /** The dataStack. */
    private final DataStack dataStack;

    /** Handle of {@code raise}. */
    private final int raiseHandle;

    /** Handle of {@code call}. */
    private final int callByHostHandle;

    /**
     * Constructs a stack machine.
     */
    StackMachine(Vm vm, CallStack callStack, DataStack dataStack) {
        this.vm = vm;
        this.callStack = callStack;
        this.dataStack = dataStack;
        this.raiseHandle = vm.sym.handleFor("raise");
        this.callByHostHandle = vm.sym.handleFor("..call by host..");
    }

    /**
     * Constructs a stack machine.
     */
    StackMachine(Vm vm) {
        this(vm, new CallStack(100, 50_000, 12), new DataStack(vm, 1000, 500_000));
    }

    /**
     * Returns the vm.
     */
    Vm getVm() {
        return this.vm;
    }

    /**
     * The outcome of the execution.
     */
    sealed interface Outcome {

        /**
         * StackMachine terminates with a val.
         *
         * @param val the returned val.
         */
        record Returned(Val val) implements Outcome {
        }

        /**
         * StackMachine terminates with an exception.
         *
         * @param exception the raised exception.
         */
        record Raised(ExceptionVal exception) implements Outcome {
        }

    }

    /**
     * Runs the stack machine with the function call.
     *
     * <p>This method can be invoked for a StackMachine instane only once.</p>
     */
    Outcome run(FunVal fun) {
        var trace = Trace.of(vm.sym.handleFor("..root.."));
        ResumeCse terminateFrame = new ResumeCse() {
            @Override
            public Trace trace(int programCounter) {
                return trace;
            }
        };
        callStack.pushCse(terminateFrame, 0, 0, 0);
        dataStack.push(vm.nada);
        transitionToCall(fun);
        return mainLoop();
    }

    //
    // Current call
    //

    /** The count of args for the current call. */
    private int argCount;

    //
    // Accessors
    //

    /**
     * Returns the callStack.
     */
    CallStack getCallStack() {
        return this.callStack;
    }

    /**
     * Returns the dataStack.
     */
    DataStack getDataStack() {
        return this.dataStack;
    }

    /**
     * Resets the argCount as if the current sp is the end of the args.
     */
    void resetArgCount() {
        int recvCount = 1;
        this.argCount = this.dataStack.topOffset() - recvCount;
    }

    //
    // Main loop
    //

    /**
     * State of the main loop.
     */
    static final class State {

        /**
         * Should not be instantiated.
         */
        State() {
            throw new UnsupportedOperationException("should not be instantiated");
        }

        /** State to handle a fun call. */
        public static final int CALL = 0;

        /** State to consume a result val. */
        public static final int CONSUME = 1;

        /** State to handle a rection. */
        public static final int  HOST_RESULT = 2;

        /** State to terminate the main loop. */
        public static final int TERMINATE = 3;

        /** Intermediate state between state transitions. */
        public static final int BETWEEN_TRANSITION = 4;

    }

    /** The state of the main loop. */
    private int state = State.BETWEEN_TRANSITION;

    /** The arg of state transition. */
    @Nullable
    private Object stateTransitionArg;

    /**
     * The main loop.
     */
    private Outcome mainLoop() {
        while (this.state != State.TERMINATE) {
            int stateCopy = this.state;
            this.state = State.BETWEEN_TRANSITION;

            Object argCopy = this.stateTransitionArg;
            this.stateTransitionArg = null;

            handle(stateCopy, argCopy);
        }

        return (Outcome) this.stateTransitionArg;
    }

    /**
     * Transition to the next state.
     */
    private void transition(int state, Object arg) {
        this.state = state;
        this.stateTransitionArg = arg;
    }

    /**
     * Handles the state.
     */
    void handle(int state, Object arg) {
        switch (state) {
            case State.CALL: {
                this.handleCall((FunVal) arg);
                break;
            }
            case State.CONSUME: {
                this.handleResult((Val) arg);
                break;
            }
            case State.HOST_RESULT: {
                this.handleHostResult((HostResultCore) arg);
                break;
            }
            default: {
                throw new UnsupportedOperationException();
            }
        }
    }

    // State.TERMINATE {{{1

    /**
     * Transitions to TERMINATE state with the stack machine result.
     */
    void transitionToTerminateWithVal(Val stackMachineResult) {
        transition(State.TERMINATE, new Outcome.Returned(stackMachineResult));
    }

    /**
     * Transitions to TERMINATE state with the failure result.
     */
    void transitionToTerminatedWithException(ExceptionVal exc) {
        transition(State.TERMINATE, new Outcome.Raised(exc));
    }

    // }}}1
    // State.CALL {{{1

    /**
     * Transitions to CALL state.
     */
    void transitionToCall(FunVal fun) {
        transition(State.CALL, fun);
    }
    /**
     * Handles a fun call.
     */
    private void handleCall(FunVal fun) {
        resetArgCount();
        fun.run(this);
    }

    // }}}1
    // State.CONSUME {{{1

    /**
     * Transitions to CONSUME state.
     */
    void transitionToResult(Val result) {
        transition(State.CONSUME, result);
    }

    /**
     * Transitions to CONSUME to raise an exception on the trace.
     *
     * Called from generated fun.
     */
    void transitionToRaiseOn(String msg, Trace trace) {
        transitionToRaiseException(vm.exception.of(msg, getTracesOn(trace)));
    }

    /**
     * Returns traces on the specified trace.
     */
    private List<TraceVal> getTracesOn(Trace trace) {
        List<TraceVal> traces = new ArrayList<>(this.callStack.traces().stream()
                .map(traceCse -> traceCse.toTraceVal(vm))
                .collect(Collectors.toList()));
        traces.add(trace.toTraceVal(this.vm));
        return traces;
    }

    /**
     * Fun to call exception(template.format(...args), traces).raise.
     */
    private FunVal makeRaiseFormatFun(
            List<? extends TraceVal> traces, String template, GraphNode... args) {
        return vm.fun.make("(raise-after-format)").action(c -> {
            var makeMessage = vm.fun.make().action(cc -> cc.call(vm.graph.format(template, args)));
            return c.call(makeMessage)
                .on((cc, message) -> cc.call(exceptionRaiseCaller(message, traces)));
        });
    }

    /**
     * Fun to call exception(message, traces).raise.
     */
    FunVal exceptionRaiseCaller(Val messageVal, List<? extends TraceVal> traces) {
        return vm.fun.make().action(c -> {
            if (! (messageVal instanceof StrVal messageStr)) {
                return c.call(vm.graph.raiseFormat(
                            "Str.format must return str, but got {}",
                            vm.graph.repr(messageVal)));
            }
            var exc = vm.exception.of(messageStr.string(), traces);
            return c.call(exc, raiseHandle);
        });
    }

    /**
     * Transitions to exception(template.format(...args) traces).new.
     */
    private void transitionToRaiseFormatOn(
            Trace trace, String template, GraphNode... args) {
        FunVal fun = makeRaiseFormatFun(getTracesOn(trace), template, args);
        this.dataStack.removeFromOffset(0);
        this.dataStack.push(vm.nada);
        transitionToCall(fun);
    }

    /**
     * Transitions to CONSUME to raise an exception
     * made from the Java throwable.
     */
    void transitionToRaiseThrowable(Throwable th) {
        var message = String.format(Locale.ROOT, "java exception: %s", th);
        var exc = vm.exception.of(message, getTraces()).chain(vm.exception.of(th));
        transitionToRaiseException(exc);
    }

    /**
     * Transitions to CONSUME to raise an exception by a message.
     */
    void transitionToRaise(String msg) {
        var exc = vm.exception.of(msg, getTraces());
        transitionToRaiseException(exc);
    }

    /**
     * Transitions to CONSUME to raise an exception .
     *
     * Implementation note:
     * This method must be implemented so that no stack entry is consumed.
     * This is because an exception can be a result of stack exhaustion,
     * and if transitionToRaiseWithTraces consumes a stack entry,
     * it can cause further stack exhaustion which results in an infinite loop.
     */
    void transitionToRaiseException(ExceptionVal exc) {
        if (this.callStack.canAbort(vm.escapeKontTag)) {
            CallStackSlice callSlice = this.callStack.abort(vm.escapeKontTag);
            int dataStackUsage = callSlice.dataStackUsage();
            this.dataStack.decreaseBp(dataStackUsage);
            this.dataStack.removeFromOffset(0);

            FunVal switchFun = vm.fun.make().take(2).action(c -> {
                Val onRaised = c.arg(0);
                if (! (onRaised instanceof FunVal onRaisedFun)) {
                    return c.call(vm.graph.raiseFormat(
                                "expected fun, but got {}",
                                vm.graph.repr(onRaised)));
                }
                return c.call(onRaisedFun).args(exc);
            });
            transitionToResult(switchFun);
        } else {
            transitionToTerminatedWithException(exc);
        }
    }

    /**
     * Handles the val resulted from a call.
     */
    private void handleResult(Val result) {
        ResumeCse frame = this.callStack.popResumer();
        long lnum = this.callStack.poppedLnum();
        this.dataStack.removeFromOffset(0);
        this.dataStack.decreaseBp(Lnums.getDataStackUsage(lnum));
        this.argCount = Lnums.getArgCount(lnum);

        if (frame instanceof GeneratedFunValBase fun) {
            this.dataStack.pop(); // pop the called fun
            fun.resume(this, result, Lnums.getProgramCounter(lnum));
        } else if (frame instanceof HostResumeCse f) {
            HostResultCore hostResult;
            try {
                hostResult = f.handler().reaction(nextCallContext(), result)
                    .makeHostResultCore();
            } catch (Throwable th) {
                transitionToRaiseThrowable(th);
                return;
            }
            transitionToHostResult(hostResult);
        } else {
            transitionToTerminateWithVal(result);
        }
    }

    // }}}1
    // State.HOST_RESULT {{{1

    /**
     * Transitions to HOST_RESULT state.
     */
    void transitionToHostResult(HostResultCore hostResult) {
        transition(State.HOST_RESULT, hostResult);
    }

    /**
     * Handles a reaction.
     */
    private void handleHostResult(HostResultCore hostResult) {
        hostResult.doTransition(this);
    }

    // }}}1
    // State.QBLOCK {{{1

    /**
     * Transition to cannot-spread exception.
     */
    void transitionToCannotSpread(Trace trace, Val vecExpected) {
        transitionToRaiseFormatOn(
                trace,
                "cannot spread non-vec val: {}",
                vm.graph.repr(vecExpected));
    }

    /**
     * Raises no such var error.
     */
    void transitionToRaiseNoSuchVar(String sym, Val owner, Trace trace) {
        transitionToRaiseFormatOn(trace,
                "no such var: {} not found in {}",
                vm.graph.of(vm.str.of(sym)),
                vm.graph.repr(owner));
    }

    /**
     * Tail-calls a fun.
     *
     * @param trace the trace of the call.
     * @param argCountToPass the count of args to pass.
     */
    void tailCallFun(Trace trace, int argCountToPass) {
        int argsOffset = dataStack.topOffset() - argCountToPass;
        int recvOffset = argsOffset - 1;
        int funOffset = recvOffset - 1;
        FunVal fun = (FunVal) this.dataStack.atOffset(funOffset);

        this.dataStack.removeToOffset(recvOffset);
        this.callStack.pushTailTrace(trace);
        transitionToCall(fun);
    }

    /**
     * Calls a fun.
     *
     * @param resumeCse cse to resume the current continuation.
     * @param resumeProgramCounter where the execution should be resumed.
     * @param argCountToPass the count of args to pass.
     */
    void callFun(ResumeCse resumeCse, int resumeProgramCounter, int argCountToPass) {
        int argsOffset = dataStack.topOffset() - argCountToPass;
        int recvOffset = argsOffset - 1;
        int funOffset = recvOffset - 1;
        FunVal fun = (FunVal) this.dataStack.atOffset(funOffset);
        this.dataStack.setAtOffset(funOffset, vm.nada);

        if (! this.callStack.isModerateSize()) {
            transitionToRaise("call stack overflow");
            return;
        }
        this.callStack.pushCse(resumeCse, resumeProgramCounter, this.argCount, recvOffset);
        this.dataStack.increaseBp(recvOffset);
        transitionToCall(fun);
    }

    /**
     * Raises not-a-fun error.
     */
    void raiseNotFun(Val funExpected, String sym, Trace errorTrace) {
        transitionToRaiseFormatOn(errorTrace,
                "not a fun: ${} is {}",
                vm.graph.of(vm.str.of(sym)),
                vm.graph.repr(funExpected));
    }

    /**
     * Rases wrong-number-of-args error.
     */
    void raiseWrongNumberOfArgs(int paramCount, String paramsRepr, VecVal args) {
        FunVal fun = ArgsSupport.wrongNumberOfArgs(
                vm, paramCount, vm.graph.of(vm.str.of(paramsRepr)), args);
        this.dataStack.removeFromOffset(0);
        this.dataStack.push(vm.nada);
        transitionToCall(fun);
    }

    /**
     * Raises not-vec-rhs error.
     */
    void raiseNotVecRhs(Val rhs) {
        FunVal fun = ArgsSupport.raiseNotVecRhs(vm, rhs);
        this.dataStack.removeFromOffset(0);
        this.dataStack.push(vm.nada);
        transitionToCall(fun);
    }

    // }}}1

    // Call context and flow {{{1

    /** Index of the call context, which may be currently active. */
    private int currentContextIndex = 0;

    /** Call context instances used in the stack machine. */
    private final CallContext[] callContexts = new CallContext[13];

    {
        for (int i = 0; i < this.callContexts.length; ++ i) {
            this.callContexts[i] = new RoundRobinCallContext(i);
        }
    }

    /**
     * Returns the next call context.
     */
    CallContext nextCallContext() {
        this.currentContextIndex = (this.currentContextIndex + 1) % this.callContexts.length;
        return this.callContexts[this.currentContextIndex];
    }

    /**
     * Context call implementation.
     *
     * Use different instances for each call,
     * to detect leak.
     * Leak detection is not perfect,
     * because the instance is reused cyclically,
     * and does not check the current thread.
     */
    private class RoundRobinCallContext implements CallContext {

        /** The index of the call context instance. */
        private final int index;

        /**
         * Constructs a call context.
         */
        RoundRobinCallContext(int index) {
            this.index = index;
        }

        /**
         * Checks which the call context is used outside of the make or handler
         * to which the context is passed.
         */
        private void checkContextLeak() {
            Preconds.checkState(this.index == StackMachine.this.currentContextIndex,
                    "context leak: CallContext instance used outside of the make or handler");
        }

        @Override
        public Val recv() {
            checkContextLeak();
            return StackMachine.this.getRecv();
        }

        /**
         * Returns the argCount.
         */
        @Override
        public int argCount() {
            checkContextLeak();
            return StackMachine.this.getArgCount();
        }

        @Override
        public Val arg(int argIndex) {
            checkContextLeak();
            return StackMachine.this.getArg(argIndex);
        }

        @Override
        public List<TraceVal> traces() {
            checkContextLeak();
            return StackMachine.this.getTraces();
        }

        @Override
        public HostResult raise(String msg) {
            checkContextLeak();
            return new RaiseMessageResult(msg);
        }

        @Override
        public HostResult raise(Throwable throwable) {
            checkContextLeak();
            return new RaiseThrowableResult(throwable);
        }

        @Override
        public CallFlowToRecv call(FunVal fun) {
            checkContextLeak();
            StackMachine.this.symHandleInFlow = StackMachine.this.callByHostHandle;
            StackMachine.this.funInFlow = fun;
            StackMachine.this.recvInFlow = vm.nada;
            return StackMachine.this.flowWithFunRecv;
        }

        @Override
        public CallFlowToRecv call(Val owner, int symHandle) {
            checkContextLeak();
            StackMachine.this.symHandleInFlow = symHandle;
            Val funExpected = owner.getVar(symHandle);
            if (funExpected instanceof FunVal) {
                StackMachine.this.funInFlow = (FunVal) funExpected;
            } else if (funExpected == null) {
                StackMachine.this.funInFlow = makeRaiseFormatFun(
                        traces(), "no such var: ${} not found in {}",
                        vm.graph.of(vm.str.of(vm.sym.symFor(symHandle))), vm.graph.repr(owner));
            } else {
                // if owner.repr is not a fun, you cannot call owner.repr
                String sym = vm.sym.symFor(symHandle);
                StackMachine.this.funInFlow = sym.equals("repr")
                    ? makeRaiseFormatFun(
                            traces(), "not a fun: $repr is {}",
                            vm.graph.repr(funExpected))
                    : makeRaiseFormatFun(
                            traces(), "not a fun: ${} in {} is {}",
                            vm.graph.of(vm.str.of(vm.sym.symFor(symHandle))),
                            vm.graph.repr(owner),
                            vm.graph.repr(funExpected));
            }
            StackMachine.this.recvInFlow = owner;
            return StackMachine.this.flowWithFunRecv;
        }

        @Override
        public CallFlowToArgs call(String modName, int symHandle) {
            Val loadedMod = vm.mod.getLoaded(modName);
            return loadedMod != null
                ? call(loadedMod, symHandle).recv(vm.nada)
                : new RequireThenInvokeCallFlow(vm, this, modName, symHandle, List.of());
        }

        @Override
        public CallFlowToOn call(GraphNode graph) {
            FunVal evaluator = vm.fun.make().action(graph::evaluateIn);
            return this.call(evaluator);
        }

    }

    /**
     * Returns the recv of the frame.
     */
    Val getRecv() {
        return this.dataStack.recv();
    }

    /**
     * Returns the argCount.
     */
    int getArgCount() {
        return this.argCount;
    }

    /**
     * Returns the specified arg of the frame.
     */
    Val getArg(int argIndex) {
        return this.dataStack.arg(argIndex);
    }

    /**
     * Returns the traces.
     */
    List<TraceVal> getTraces() {
        return new ArrayList<>(this.callStack.traces().stream()
                .map(traceCse -> traceCse.toTraceVal(vm))
                .collect(Collectors.toList()));
    }

    /** The recv set in the call flow. */
    @Nullable
    private Val recvInFlow;

    /** The fun set in the call flow. */
    @Nullable
    private FunVal funInFlow;

    /** The sym handle set in the call flow. */
    private int symHandleInFlow;

    /** Call flow with a fun and a recv. */
    final WithFunRecv flowWithFunRecv = makeFlowWithFunRecv(new WithFunRecv());

    /**
     * Provided to be overridden by test code.
     */
    WithFunRecv makeFlowWithFunRecv(WithFunRecv flow) {
        return flow;
    }

    /**
     * Call flow with a fun and a recv.
     */
    class WithFunRecv implements CallFlowToRecv {

        @Override
        public CallFlowToArgs recv(Val recv) {
            StackMachine.this.recvInFlow = recv;
            return this;
        }

        /**
         * Prepares for passing args.
         */
        @Nullable
        CallFlowToOn prepareArgs(int arity) {
            int recv = 1;
            if (! dataStack.ensureCapaSpPlus(recv + arity)) {
                StackMachine.this.funInFlow = null;
                StackMachine.this.recvInFlow = null;
                return new RaiseMessageResult("data stack overflow");
            }
            dataStack.push(StackMachine.this.recvInFlow);
            StackMachine.this.recvInFlow = null;
            return null;
        }

        @Override
        public CallFlowToOn args() {
            @Nullable CallFlowToOn aborted = prepareArgs(0);
            if (aborted != null) {
                return aborted;
            }

            return flowWithArgs;
        }

        @Override
        public CallFlowToOn args(Val arg0) {
            @Nullable CallFlowToOn aborted = prepareArgs(1);
            if (aborted != null) {
                return aborted;
            }

            dataStack.push(arg0);
            return flowWithArgs;
        }

        @Override
        public CallFlowToOn args(Val arg0, Val arg1) {
            @Nullable CallFlowToOn aborted = prepareArgs(2);
            if (aborted != null) {
                return aborted;
            }

            dataStack.push(arg0);
            dataStack.push(arg1);
            return flowWithArgs;
        }

        @Override
        public CallFlowToOn args(Val arg0, Val arg1, Val arg2) {
            @Nullable CallFlowToOn aborted = prepareArgs(3);
            if (aborted != null) {
                return aborted;
            }

            dataStack.push(arg0);
            dataStack.push(arg1);
            dataStack.push(arg2);
            return flowWithArgs;
        }

        @Override
        public CallFlowToOn args(Val arg0, Val arg1, Val arg2, Val arg3) {
            @Nullable CallFlowToOn aborted = prepareArgs(4);
            if (aborted != null) {
                return aborted;
            }

            dataStack.push(arg0);
            dataStack.push(arg1);
            dataStack.push(arg2);
            dataStack.push(arg3);
            return flowWithArgs;
        }

        @Override
        public CallFlowToOn args(Val arg0, Val arg1, Val arg2, Val arg3, Val arg4) {
            @Nullable CallFlowToOn aborted = prepareArgs(5);
            if (aborted != null) {
                return aborted;
            }

            dataStack.push(arg0);
            dataStack.push(arg1);
            dataStack.push(arg2);
            dataStack.push(arg3);
            dataStack.push(arg4);
            return flowWithArgs;
        }

        @Override
        public CallFlowToOn args(Val... args) {
            @Nullable CallFlowToOn aborted = prepareArgs(args.length);
            if (aborted != null) {
                return aborted;
            }

            for (Val arg : args) {
                dataStack.push(arg);
            }
            return flowWithArgs;
        }

        @Override
        public HostResult on(HostFunReaction retValHandler) {
            return args().on(retValHandler);
        }

        @Override
        public HostResultCore makeHostResultCore() {
            return args().makeHostResultCore();
        }

    }

    /** Call flow with a fun, recv and args. */
    final WithArgs flowWithArgs = new WithArgs();

    /**
     * Call flow with a fun, recv and args.
     */
    private class WithArgs extends HostResultCore implements CallFlowToOn, HostResult {
        @Override
        public HostResult on(HostFunReaction handler) {
            if (! callStack.isModerateSize()) {
                StackMachine.this.funInFlow = null;
                StackMachine.this.recvInFlow = null;
                return new RaiseMessageResult("call stack overflow");
            }

            HostResumeCse frame = new HostResumeCse(handler, StackMachine.this.symHandleInFlow);
            int dataStackUsage = StackMachine.this.argCount + 1;
            callStack.pushCse(frame, 0, StackMachine.this.argCount, dataStackUsage);
            int recvCount = 1;
            StackMachine.this.dataStack.increaseBp(recvCount + StackMachine.this.argCount);
            return CALL_REACTION;
        }

        @Override
        void doTransition(StackMachine stackMachine) {
            int recvCount = 1;
            dataStack.removeToOffset(recvCount + StackMachine.this.argCount);
            Trace trace = Trace.ofTail(stackMachine.symHandleInFlow);
            stackMachine.callStack.pushTailTrace(trace);
            CALL_REACTION.doTransition(stackMachine);
        }

        @Override
        public HostResultCore makeHostResultCore() {
            return this;
        }

    }

    /** HostResultCore to call the function set in the call flow. */
    private static final CallReaction CALL_REACTION = new CallReaction();

    /**
     * HostResultCore to call the function set in the call flow.
     */
    private static class CallReaction extends HostResultCore implements HostResult {

        @Override
        void doTransition(StackMachine stackMachine) {
            FunVal fun = stackMachine.funInFlow;
            stackMachine.funInFlow = null;
            stackMachine.transitionToCall(fun);
        }

        @Override
        public HostResultCore makeHostResultCore() {
            return this;
        }

    }

    // }}}1

}

// vim: et sw=4 sts=4 fdm=marker
