package org.kink_lang.kink;

import java.lang.invoke.MethodHandles;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.OptionalInt;
import java.util.function.Function;

import org.kink_lang.kink.hostfun.CallContext;
import org.kink_lang.kink.hostfun.HostFunBuilder;
import org.kink_lang.kink.hostfun.HostResult;
import org.kink_lang.kink.internal.compile.ItreeCompiler;
import org.kink_lang.kink.internal.function.ThrowingFunction2;
import org.kink_lang.kink.internal.program.lex.Lexer;
import org.kink_lang.kink.internal.program.ast.Parser;
import org.kink_lang.kink.internal.program.itree.NodeToItreeTranslator;
import org.kink_lang.kink.internal.program.itreeoptimize.ItreeOptimizers;

/**
 * A helper for {@linkplain FunVal funs}.
 *
 * @see Vm#fun
 */
public class FunHelper {

    /** The vm. */
    private final Vm vm;

    /**
     * Shared vars of fun vals, holding the {@code call} and {@code repr} methods.
     *
     * <p>Assigned by {@link #init()}; the constructor leaves it null.</p>
     */
    SharedVars sharedVars;

    /** The builder returned by {@link #make()}; the same instance is returned on every call. */
    private final HostFunBuilder defaultHostFunBuilder;

    /**
     * Constructs the helper.
     *
     * <p>{@link #sharedVars} is not set up here; it is populated separately by {@link #init()}.</p>
     *
     * @param vm the vm.
     */
    FunHelper(Vm vm) {
        this.vm = vm;
        this.defaultHostFunBuilder = new HostFunBuilderImpl(vm, "(fun)", 0, OptionalInt.empty());
    }

    // compile {{{1

    /**
     * Compiles a program text to a top level fun.
     *
     * <p>If the program text is successfully compiled to a top level fun,
     * {@code onSucc} is called with the fun.</p>
     *
     * <p>If the program text cannot be compiled and causes a compile error,
     * {@code onError} is called with the compile error.</p>
     *
     * @param locale the locale to generate compile error messages.
     * @param programName the name of the program.
     * @param programText the program text.
     * @param binding the binding of the top level fun.
     * @param onSucc the function which is called when the program is successfully compiled.
     * @param onError the function which is called when the program causes a compile error.
     * @param <T> the type of the result.
     * @return the result of onSucc or onError.
     */
    public <T> T compile(
            Locale locale, String programName, String programText,
            BindingVal binding,
            Function<? super FunVal, ? extends T> onSucc,
            Function<? super CompileError, ? extends T> onError) {
        Lexer lexer = new Lexer(locale);
        Parser parser = new Parser(locale);
        ItreeCompiler compiler
            = new ItreeCompiler(vm, MethodHandles.lookup(), programName, programText, binding);
        // On a successful parse, the AST is translated to an itree, optimized,
        // compiled to a fun, and finally handed to onSucc.
        return parser.parse(lexer.apply(programText),
                new NodeToItreeTranslator()
                .andThen(ItreeOptimizers.getOptimizer())
                .andThen(compiler::compile)
                .andThen(onSucc),
                (msg, fromPos, toPos) -> {
                    // Convert the raw positions of the error into locations in the program text.
                    var from = vm.location.of(programName, programText, fromPos);
                    var to = vm.location.of(programName, programText, toPos);
                    var error = new CompileError(msg, from, to);
                    return onError.apply(error);
                });
    }

    // }}}1

    // host fun builder {{{1

    /**
     * Returns a builder of a host fun.
     *
     * <p>Calling this method is equivalent to calling
     * {@code make("(fun)")}.</p>
     *
     * <p>The same builder instance is returned on every call.</p>
     *
     * <p>See {@link org.kink_lang.kink.hostfun} for detail.</p>
     *
     * @return a builder.
     */
    public HostFunBuilder make() {
        return defaultHostFunBuilder;
    }

    /**
     * Returns a builder of a host fun, with the description of the fun.
     *
     * <p>See {@link org.kink_lang.kink.hostfun} for detail.</p>
     *
     * @param desc the description of the fun, which is used in arity error messages.
     * @return a builder.
     */
    public HostFunBuilder make(String desc) {
        return new HostFunBuilderImpl(vm, desc, 0, OptionalInt.empty());
    }

    // }}}1

    // constant {{{1

    /**
     * Returns a constant fun: which is a 0-ary fun returning the val constantly.
     *
     * @param val the constant val.
     * @return a constant fun.
     */
    public FunVal constant(Val val) {
        return make("(constant)").take(0).action(c -> val);
    }

    // }}}1

    /**
     * Initializes the helper.
     *
     * <p>Builds the {@code call} and {@code repr} methods and stores them
     * in {@link #sharedVars}.</p>
     */
    void init() {
        Map<Integer, Val> vars = new HashMap<>();
        vars.put(vm.sym.handleFor("call"),
                vm.fun.make("Fun.call(Recv Args)").take(2).action(this::callMethod));
        vars.put(vm.sym.handleFor("repr"),
                method0("Fun.repr", (c, fun) -> vm.str.of(fun.getRepr())));

        this.sharedVars = vm.sharedVars.of(vars);
    }

    /**
     * Makes a nullary method fun whose receiver must be a fun.
     *
     * <p>If the receiver is not a fun, the method raises an error whose
     * message is prefixed with {@code prefix}.</p>
     *
     * @param prefix the description of the method fun, also used in the error message.
     * @param action the action, called with the call context and the receiver fun.
     * @return the method fun.
     */
    private FunVal method0(
            String prefix,
            ThrowingFunction2<CallContext, FunVal, HostResult> action) {
        return vm.fun.make(prefix).take(0).action(c -> {
            Val recv = c.recv();
            if (! (recv instanceof FunVal fun)) {
                return c.call(vm.graph.raiseFormat("{}: required fun for Fun but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(recv)));
            }
            return action.apply(c, fun);
        });
    }

    /**
     * Implements {@code Fun.call(Recv Args)}.
     *
     * <p>Calls the receiver fun with {@code Recv} as its receiver and the
     * elements of the {@code Args} vec as its arguments. Raises an error
     * if the receiver is not a fun, or if {@code Args} is not a vec.</p>
     *
     * @param c the call context.
     * @return the result of the call.
     */
    private HostResult callMethod(CallContext c) {
        if (! (c.recv() instanceof FunVal fun)) {
            return c.call(vm.graph.raiseFormat(
                        "Fun.call(Recv Args): Fun must be fun, but got {}",
                        vm.graph.repr(c.recv())));
        }
        Val recv = c.arg(0);
        if (! (c.arg(1) instanceof VecVal args)) {
            return c.call(vm.graph.raiseFormat(
                        "Fun.call(Recv Args): Args must be vec, but got {}",
                        vm.graph.repr(c.arg(1))));
        }
        return c.call(fun).recv(recv).args(args.toList().toArray(Val[]::new));
    }

}

// vim: et sw=4 sts=4 fdm=marker
