package org.kink_lang.kink;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import javax.annotation.Nullable;

import org.kink_lang.kink.hostfun.CallContext;
import org.kink_lang.kink.hostfun.HostResult;
import org.kink_lang.kink.internal.function.ThrowingFunction2;
import org.kink_lang.kink.internal.function.ThrowingFunction3;

/**
 * The helper of exception vals: creates exception vals and provides their methods.
 */
public class ExceptionHelper {

    /** The vm. */
    private final Vm vm;

    /** Shared vars of exception vals; assigned by {@link #init()}. */
    SharedVars sharedVars;

    /** Sym handle of {@code _desc_iter}; assigned by {@link #init()}. */
    private int descIterHandle;

    /**
     * Constructs a helper.
     *
     * <p>{@link #init()} must be called before the methods of exception vals are used,
     * because it assigns {@link #sharedVars} and the sym handle used by {@code desc_iter}.
     *
     * @param vm the vm.
     */
    ExceptionHelper(Vm vm) {
        this.vm = vm;
    }

    /**
     * Initializes the helper.
     *
     * <p>Resolves the sym handle of {@code _desc_iter},
     * and creates {@link #sharedVars} holding the methods of exception vals.
     */
    void init() {
        this.descIterHandle = vm.sym.handleFor("_desc_iter");
        Map<Integer, Val> vars = new HashMap<>();
        addMethod0(vars, "Exception", "message", (c, exc) -> vm.str.of(exc.message()));
        addMethod0(vars, "Exception", "traces", (c, exc) -> vm.vec.of(exc.traces()));
        addMethod0(vars, "Exception", "have_next?", (c, exc) -> vm.bool.of(exc.next().isPresent()));
        addMethod0(vars, "Exception", "next", this::nextMethod);
        addMethod0(vars, "Exception", "raise", this::raiseMethod);
        addBinaryOp(vars, "Exception", "op_add", "Tail", this::opAddMethod);
        addMethod0(vars, "Exception", "desc_iter", this::descIterMethod);
        addMethod0(vars, "Exception", "repr", this::reprMethod);
        addBinaryOp(vars, "Exception", "op_eq", "Arg_exception", this::opEqMethod);
        this.sharedVars = vm.sharedVars.of(vars);
    }

    /**
     * Makes an exception val without chaining.
     *
     * @param message the exception message.
     * @param traces the exception traces.
     * @return an exception val without chaining.
     */
    public ExceptionVal of(String message, List<? extends TraceVal> traces) {
        return new ExceptionVal(vm, message, traces);
    }

    /**
     * Makes an exception val converted from Java Throwable.
     *
     * <p>{@link #flatten(Throwable)} flattens the throwable, its causes and its suppressed
     * throwables into a list. The list is then converted into a chain of exception vals
     * in the same order.
     *
     * @param th the Java Throwable; must not be null.
     * @return an exception val.
     * @throws NullPointerException if {@code th} is null.
     */
    public ExceptionVal of(Throwable th) {
        Objects.requireNonNull(th, "th");

        // Never empty: flatten always contains th itself when th is non-null.
        List<Throwable> throwables = flatten(th);

        // Build the chain from the back, so that each earlier throwable
        // becomes the head of the chain built so far.
        int last = throwables.size() - 1;
        ExceptionVal chain = ofFlat(throwables.get(last), null);
        for (int i = last - 1; i >= 0; -- i) {
            chain = ofFlat(throwables.get(i), chain);
        }
        return chain;
    }

    /**
     * Makes an exception val from a single throwable,
     * without traversing its causes and suppressed throwables.
     *
     * <p>The message is {@code th.toString()}. The traces are made from
     * {@code th.getStackTrace()} in reverse order, so the bottom of the stack comes first.
     *
     * @param th the throwable.
     * @param tail the exception chained after the result, or null for no chaining.
     * @return the exception val.
     */
    ExceptionVal ofFlat(Throwable th, @Nullable ExceptionVal tail) {
        StackTraceElement[] stes = th.getStackTrace();
        List<TraceVal> traces = new ArrayList<>(stes.length);
        for (int i = stes.length - 1; i >= 0; -- i) {
            traces.add(vm.trace.of(stes[i].toString()));
        }
        ExceptionVal head = of(th.toString(), traces);
        return tail == null ? head : head.chain(tail);
    }

    /**
     * Makes a flat list of the throwable, its causes and its suppressed throwables.
     *
     * <p>The tree is traversed depth-first in pre-order. Each throwable is followed
     * by the subtree of its cause, then by the subtrees of its suppressed throwables
     * in order.
     *
     * <p>Each throwable appears at most once. Duplicates and cycles are detected
     * by identity, so overridden {@code equals} methods do not affect the result.
     * The traversal uses an explicit stack, so deep cause chains do not overflow
     * the Java call stack.
     *
     * @param th the root throwable, or null.
     * @return an unmodifiable flat list; empty if {@code th} is null.
     */
    static List<Throwable> flatten(@Nullable Throwable th) {
        List<Throwable> accum = new ArrayList<>();
        Set<Throwable> visited = Collections.newSetFromMap(new IdentityHashMap<>());
        Deque<Throwable> pending = new ArrayDeque<>();
        pushIfPresent(pending, th);
        while (! pending.isEmpty()) {
            Throwable cur = pending.pop();
            if (! visited.add(cur)) {
                continue;
            }
            accum.add(cur);

            // Push in reverse so that the cause is popped first,
            // followed by the suppressed throwables in their original order.
            Throwable[] suppressed = cur.getSuppressed();
            for (int i = suppressed.length - 1; i >= 0; -- i) {
                pushIfPresent(pending, suppressed[i]);
            }
            pushIfPresent(pending, cur.getCause());
        }
        return Collections.unmodifiableList(accum);
    }

    /**
     * Pushes the throwable onto the stack unless it is null.
     * ArrayDeque rejects null elements, so the check is needed.
     */
    private static void pushIfPresent(Deque<Throwable> stack, @Nullable Throwable th) {
        if (th != null) {
            stack.push(th);
        }
    }

    /**
     * Adds a nullary method.
     *
     * <p>The method raises an exception if the receiver is not an exception val.
     *
     * @param vars the vars to which the method is put.
     * @param recvDesc the description of the receiver, such as {@code Exception}.
     * @param methodName the name of the method.
     * @param action the body of the method, invoked with the receiver exception.
     */
    private void addMethod0(
            Map<Integer, Val> vars,
            String recvDesc,
            String methodName,
            ThrowingFunction2<CallContext, ExceptionVal, HostResult> action) {
        String desc = String.format(Locale.ROOT, "%s.%s", recvDesc, methodName);
        FunVal fun = vm.fun.make(desc).action(c -> {
            Val recv = c.recv();
            return recv instanceof ExceptionVal exc
                ? action.apply(c, exc)
                : raiseNotException(c, desc, recvDesc, recv);
        });
        vars.put(vm.sym.handleFor(methodName), fun);
    }

    /**
     * Adds a binary operator method.
     *
     * <p>The method raises an exception if the receiver or the first argument
     * is not an exception val. The receiver is checked first.
     *
     * @param vars the vars to which the method is put.
     * @param recvDesc the description of the receiver, such as {@code Exception}.
     * @param methodName the name of the method.
     * @param argDesc the description of the argument, such as {@code Tail}.
     * @param action the body of the method, invoked with the receiver and the argument.
     */
    private void addBinaryOp(
            Map<Integer, Val> vars,
            String recvDesc,
            String methodName,
            String argDesc,
            ThrowingFunction3<CallContext, ExceptionVal, ExceptionVal, HostResult> action) {
        String desc = String.format(Locale.ROOT, "%s.%s(%s)", recvDesc, methodName, argDesc);
        FunVal fun = vm.fun.make(desc).action(c -> {
            Val recv = c.recv();
            if (! (recv instanceof ExceptionVal recvExc)) {
                return raiseNotException(c, desc, recvDesc, recv);
            }

            Val arg = c.arg(0);
            if (! (arg instanceof ExceptionVal argExc)) {
                return raiseNotException(c, desc, argDesc, arg);
            }

            return action.apply(c, recvExc, argExc);
        });
        vars.put(vm.sym.handleFor(methodName), fun);
    }

    /**
     * Raises an exception reporting that the val is not an exception val.
     *
     * @param c the call context.
     * @param desc the description of the method, such as {@code Exception.op_add(Tail)}.
     * @param valDesc the description of the checked val, such as {@code Tail}.
     * @param val the val which is not an exception val.
     * @return the result which raises the exception.
     */
    private HostResult raiseNotException(CallContext c, String desc, String valDesc, Val val) {
        return c.call(vm.graph.raiseFormat("{}: {} must be exception, but got {}",
                    vm.graph.of(vm.str.of(desc)),
                    vm.graph.of(vm.str.of(valDesc)),
                    vm.graph.repr(val)));
    }

    /**
     * Exception.next method.
     *
     * <p>Returns the next exception in the chain,
     * or raises an exception if there is none.
     */
    private HostResult nextMethod(CallContext c, ExceptionVal recv) {
        return recv.next()
            .map(next -> (HostResult) next)
            .orElseGet(() -> c.call(vm.graph.raiseFormat(
                            "Exception.next: no next exception for {}",
                            vm.graph.repr(recv))));
    }

    // raise {{{

    /**
     * Exception.raise method.
     *
     * <p>Calls the fun made by {@link #raiseDelegate(ExceptionVal)} for the receiver.
     */
    private HostResult raiseMethod(CallContext c, ExceptionVal recv) {
        return c.call(raiseDelegate(recv));
    }

    /**
     * Makes a fun which, when run, makes the stack machine raise the exception.
     *
     * @param exc the exception to raise.
     * @return the fun which raises {@code exc}.
     */
    FunVal raiseDelegate(ExceptionVal exc) {
        return new FunVal(vm) {
            @Override void run(StackMachine stackMachine) {
                stackMachine.transitionToRaiseException(exc);
            }
        };
    }

    // }}}

    /**
     * Exception.desc_iter method.
     *
     * <p>Calls {@code _desc_iter} of the {@code kink/EXCEPTION} mod with the receiver as the argument.
     */
    private HostResult descIterMethod(CallContext c, ExceptionVal recv) {
        return c.call("kink/EXCEPTION", descIterHandle).args(recv);
    }

    /**
     * Exception.repr method.
     *
     * <p>Produces {@code (exception MESSAGE)} if there is no next exception,
     * or {@code (exception MESSAGE NEXT)} otherwise. MESSAGE is the repr of the message string.
     * NEXT is the repr of the next exception.
     */
    private HostResult reprMethod(CallContext c, ExceptionVal recv) {
        var message = vm.graph.repr(vm.str.of(recv.message()));
        return c.call(recv.next()
            .map(next -> vm.graph.format("(exception {} {})", message, vm.graph.repr(next)))
            .orElseGet(() -> vm.graph.format("(exception {})", message)));
    }

    /**
     * Exception.op_eq(Arg_exception) method.
     *
     * <p>Returns whether {@code recv.equals(arg)}.
     */
    private HostResult opEqMethod(CallContext c, ExceptionVal recv, ExceptionVal arg) {
        return vm.bool.of(recv.equals(arg));
    }

    /**
     * Exception.op_add(Tail) method.
     *
     * <p>Returns {@code left.chain(right)}, with {@code right} as the tail.
     */
    private HostResult opAddMethod(CallContext c, ExceptionVal left, ExceptionVal right) {
        return left.chain(right);
    }

}

// vim: et sw=4 sts=4 fdm=marker
