package org.kink_lang.kink;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.OptionalInt;
import java.util.function.UnaryOperator;

import org.kink_lang.kink.internal.function.ThrowingFunction2;
import org.kink_lang.kink.internal.function.ThrowingFunction3;
import org.kink_lang.kink.internal.function.ThrowingFunction4;
import org.kink_lang.kink.hostfun.HostContext;
import org.kink_lang.kink.hostfun.HostFunBuilder;
import org.kink_lang.kink.hostfun.CallContext;
import org.kink_lang.kink.hostfun.HostResult;
import org.kink_lang.kink.internal.num.NumOperations;

/**
 * The helper for num vals.
 */
public class NumHelper {

    /** The vm. */
    private final Vm vm;

    /** The sym handle of "new"; set by {@link #init()}. */
    private int newHandle;

    /** The sym handle of "_show_num"; set by {@link #init()}. */
    private int showNumHandle;

    /** Shared vars of num vals; assigned by {@link #init()}. */
    SharedVars sharedVars;

    /**
     * Constructs the helper.
     *
     * @param vm the vm.
     */
    NumHelper(Vm vm) {
        this.vm = vm;
    }

    /**
     * Returns a num val representing the specified number.
     *
     * @param bigDecimal a BigDecimal number.
     * @return a num val representing the specified number.
     * @throws IllegalArgumentException when the scale is negative.
     */
    public NumVal of(BigDecimal bigDecimal) {
        return new NumVal(vm, bigDecimal);
    }

    /**
     * Returns a num val representing the specified number.
     *
     * @param intNum an int number.
     * @return a num val representing the specified number.
     */
    public NumVal of(int intNum) {
        return new NumVal(vm, BigDecimal.valueOf(intNum));
    }

    /**
     * Returns a num val representing the specified number.
     *
     * @param longNum a long number.
     * @return a num val representing the specified number.
     */
    public NumVal of(long longNum) {
        return new NumVal(vm, BigDecimal.valueOf(longNum));
    }

    /**
     * Returns a num val representing the specified number.
     *
     * @param bigInt a BigInteger number.
     * @return a num val representing the specified number.
     */
    public NumVal of(BigInteger bigInt) {
        return new NumVal(vm, new BigDecimal(bigInt));
    }

    /**
     * Initializes the helper.
     *
     * <p>Resolves the sym handles used by the methods,
     * and builds {@link #sharedVars} from the methods of num vals.</p>
     */
    void init() {
        this.newHandle = vm.sym.handleFor("new");
        this.showNumHandle = vm.sym.handleFor("_show_num");

        Map<Integer, Val> vars = new HashMap<>();
        addUnaryOp(vars, "Num", "mantissa", (c, n, d) -> vm.num.of(n.unscaledValue()));
        addUnaryOp(vars, "Num", "scale", (c, n, d) -> vm.num.of(n.scale()));
        addUnaryOp(vars, "Num", "int?", (c, n, d) -> vm.bool.of(n.scale() == 0));
        addBinaryOp(vars, "Num", "op_add", "Arg_num", (c, x, y, desc) -> vm.num.of(x.add(y)));
        addBinaryOp(vars, "Num", "op_sub", "Arg_num", (c, x, y, desc) -> vm.num.of(x.subtract(y)));
        addBinaryOp(vars, "Num", "op_mul", "Arg_num", (c, x, y, desc) -> vm.num.of(x.multiply(y)));
        addBinaryOp(vars, "Num", "op_div", "Divisor", this::opDivMethod);
        addBinaryOp(vars, "Num", "op_intdiv", "Divisor", this::opIntDivMethod);
        addBinaryOp(vars, "Num", "op_rem", "Divisor", this::opRemMethod);
        addUnaryOp(vars, "Num", "op_minus", (c, n, d) -> vm.num.of(n.negate()));
        addBinaryIntOp(vars, "Num", "op_or", "Arg_num", (c, x, y) -> vm.num.of(x.or(y)));
        addBinaryIntOp(vars, "Num", "op_xor", "Arg_num", (c, x, y) -> vm.num.of(x.xor(y)));
        addBinaryIntOp(vars, "Num", "op_and", "Arg_num", (c, x, y) -> vm.num.of(x.and(y)));
        addUnaryIntOp(vars, "Num", "op_not", (c, x) -> vm.num.of(x.not()));
        addShiftMethod(vars, "Num", "op_shl", "Bit_count", BigInteger::shiftLeft);
        addShiftMethod(vars, "Num", "op_shr", "Bit_count", BigInteger::shiftRight);
        addMethod(vars, "Num", "round", "", 0, this::roundMethod);
        addUnaryOp(vars, "Num", "abs", (c, n, d) -> vm.num.of(n.abs()));
        addBinaryOp(vars, "Num1", "op_eq", "Num2", (c, x, y, d) -> vm.bool.of(x.compareTo(y) == 0));
        addBinaryOp(vars, "Num1", "op_lt", "Num2", (c, x, y, d) -> vm.bool.of(x.compareTo(y) < 0));
        addUnaryOp(vars, "Num", "times", this::timesMethod);
        addUnaryOp(vars, "Num", "up", this::upMethod);
        addUnaryOp(vars, "Num", "down", this::downMethod);
        addUnaryOp(vars, "Num", "repr", this::reprMethod);
        addMethod(vars, "Num", "show", "(...[$config])", f -> f.takeMinMax(0, 1), this::showMethod);
        this.sharedVars = vm.sharedVars.of(vars);
    }

    // Division helpers {{{1

    /**
     * Raises a zero division exception.
     *
     * @param c the host context.
     * @param desc the description of the method.
     * @param numer the number which was going to be divided by zero.
     * @return the result which raises the exception.
     */
    private HostResult raiseZeroDivision(HostContext c, String desc, BigDecimal numer) {
        return c.raise(String.format(Locale.ROOT,
                    "%s: zero division: %s is divided by 0", desc, numer.toPlainString()));
    }

    /**
     * Returns the integral quotient of numer / denom.
     *
     * <p>The quotient is rounded toward negative infinity when denom is positive,
     * and toward positive infinity when denom is negative. Consequently the remainder
     * {@code numer - quot * denom} always lies in the range {@code [0, |denom|)}.</p>
     *
     * @param numer the numerator.
     * @param denom the denominator; must not be zero.
     * @return the quotient, with scale 0.
     */
    private static BigDecimal intQuotient(BigDecimal numer, BigDecimal denom) {
        RoundingMode mode = denom.signum() < 0 ? RoundingMode.CEILING : RoundingMode.FLOOR;
        return numer.divide(denom, 0, mode);
    }

    // }}}1

    // Numer.op_div(Denom) {{{1

    /**
     * Implementation of Numer.op_div(Denom).
     *
     * <p>Raises on a zero divisor; otherwise delegates to kink/NUM_DIV.new(Numer, Denom).</p>
     */
    private HostResult opDivMethod(
            CallContext c, BigDecimal dividend, BigDecimal divisor, String desc) {
        if (divisor.signum() == 0) {
            return raiseZeroDivision(c, desc, dividend);
        }
        return c.call("kink/NUM_DIV", newHandle).args(c.recv(), c.arg(0));
    }

    // }}}1

    // Numer.op_intdiv(Denom) {{{1

    /**
     * Implementation of Numer.op_intdiv(Denom).
     *
     * <p>The result is the integral quotient such that the remainder is in {@code [0, |Denom|)}.</p>
     */
    private HostResult opIntDivMethod(
            HostContext c, BigDecimal numer, BigDecimal denom, String desc) {
        if (denom.signum() == 0) {
            return raiseZeroDivision(c, desc, numer);
        }
        return vm.num.of(intQuotient(numer, denom));
    }

    // }}}1

    // Numer.op_rem(Denom) {{{1

    /**
     * Implementation of Numer.op_rem(Denom).
     *
     * <p>The result is {@code Numer - Q * Denom} where Q is the result of op_intdiv,
     * thus it is in {@code [0, |Denom|)}.</p>
     */
    private HostResult opRemMethod(HostContext c, BigDecimal numer, BigDecimal denom, String desc) {
        if (denom.signum() == 0) {
            return raiseZeroDivision(c, desc, numer);
        }
        BigDecimal quot = intQuotient(numer, denom);
        BigDecimal rem = numer.subtract(quot.multiply(denom));
        return vm.num.of(rem);
    }

    // }}}1

    // Num.round {{{1

    /**
     * Implementation of Num.round.
     *
     * <p>Delegates to kink/NUM_DIV.new(Num, 1).</p>
     */
    private HostResult roundMethod(CallContext c, NumVal num, String desc) {
        return c.call("kink/NUM_DIV", newHandle).args(c.recv(), vm.num.of(1));
    }

    // }}}1

    // Num.times {{{1

    /**
     * Implementation of Num.times.
     *
     * <p>Num must be a nonnegative int num. Returns kink/iter/ITER.new applied to an ifun
     * which starts from 0.</p>
     */
    private HostResult timesMethod(HostContext c, BigDecimal count, String desc) {
        if (count.scale() != 0 || count.signum() < 0) {
            return c.call(vm.graph.raiseFormat(
                        "{}: Num must be a nonnegative int num, but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.of(vm.num.of(count))));
        }
        return c.call("kink/iter/ITER", newHandle).args(timesIfun(BigDecimal.ZERO, count));
    }

    /**
     * Ifun of Num.times.
     *
     * <p>While ind is less than count, the ifun calls $proc with ind and the ifun for ind + 1;
     * otherwise it calls $fin.</p>
     */
    private FunVal timesIfun(BigDecimal ind, BigDecimal count) {
        return makeIfun("Num.times-ifun", (c, proc, fin) -> ind.compareTo(count) < 0
                ? c.call(proc).args(vm.num.of(ind), timesIfun(ind.add(BigDecimal.ONE), count))
                : c.call(fin));
    }

    // }}}1
    // Num.up and .down {{{1

    /**
     * Implementation of Num.up.
     */
    private HostResult upMethod(HostContext c, BigDecimal start, String desc) {
        return c.call("kink/iter/ITER", newHandle)
            .args(upDownIfun("Num.up-ifun", start, BigDecimal.ONE));
    }

    /**
     * Implementation of Num.down.
     */
    private HostResult downMethod(HostContext c, BigDecimal start, String desc) {
        return c.call("kink/iter/ITER", newHandle)
            .args(upDownIfun("Num.down-ifun", start, BigDecimal.valueOf(-1)));
    }

    /**
     * Ifun of Num.up and Num.down.
     *
     * <p>The ifun calls $proc with num and the ifun for num + delta.
     * It never calls $fin, so the resulting iter has no end.</p>
     */
    private FunVal upDownIfun(String prefix, BigDecimal num, BigDecimal delta) {
        return makeIfun(prefix, (c, proc, fin) ->
                c.call(proc).args(vm.num.of(num), upDownIfun(prefix, num.add(delta), delta)));
    }

    // }}}1
    // Num.repr {{{1

    /**
     * Implementation of Num.repr.
     */
    private HostResult reprMethod(HostContext c, BigDecimal num, String desc) {
        return vm.str.of(num.scale() < 0
            ? reprWithScale(num)
            : reprBasedOnPlain(num));
    }

    /**
     * Makes repr representation based on the plain string.
     *
     * <p>Negative numbers are parenthesized.</p>
     */
    private String reprBasedOnPlain(BigDecimal num) {
        return num.signum() < 0
            ? String.format(Locale.ROOT, "(%s)", num.toPlainString())
            : num.toPlainString();
    }

    /**
     * Makes repr representation with scale.
     */
    private String reprWithScale(BigDecimal num) {
        BigInteger mantissa = num.unscaledValue();
        int scale = num.scale();
        return String.format(Locale.ROOT, "(num mantissa=%d scale=%d)", mantissa, scale);
    }

    // }}}1

    // Num.show(...[$config]) {{{1

    /**
     * Implementation of Num.show(...[$config]).
     *
     * <p>When $config is omitted, a fun which takes one arg and returns nada is used instead.
     * Delegates to kink/NUM._show_num(Num, Config).</p>
     */
    private HostResult showMethod(CallContext c, NumVal num, String desc) {
        Val config = c.argCount() == 0
            ? vm.fun.make().take(1).action(cc -> vm.nada)
            : c.arg(0);
        return c.call("kink/NUM", showNumHandle).args(num, config);
    }

    // }}}1

    /**
     * Raises an exception telling that a val is not a num.
     *
     * @param c the host context.
     * @param desc the description of the method.
     * @param valDesc the description of the val, such as "Num" or "Arg_num".
     * @param val the offending val.
     * @return the result which raises the exception.
     */
    private HostResult raiseNotNum(HostContext c, String desc, String valDesc, Val val) {
        return c.call(vm.graph.raiseFormat("{}: {} must be a num, but was {}",
                    vm.graph.of(vm.str.of(desc)),
                    vm.graph.of(vm.str.of(valDesc)),
                    vm.graph.repr(val)));
    }

    /**
     * Raises an exception telling that a val is not an int num.
     *
     * @param c the host context.
     * @param desc the description of the method.
     * @param valDesc the description of the val, such as "Num" or "Arg_num".
     * @param val the offending val.
     * @return the result which raises the exception.
     */
    private HostResult raiseNotIntNum(HostContext c, String desc, String valDesc, Val val) {
        return c.call(vm.graph.raiseFormat("{}: {} must be an int num, but was {}",
                    vm.graph.of(vm.str.of(desc)),
                    vm.graph.of(vm.str.of(valDesc)),
                    vm.graph.repr(val)));
    }

    /**
     * Adds a unary operator method to vars.
     *
     * <p>The action receives the BigDecimal of the receiver.</p>
     */
    private void addUnaryOp(
            Map<Integer, Val> vars,
            String recvDesc,
            String sym,
            ThrowingFunction3<CallContext, BigDecimal, String, HostResult> action) {
        addMethod(vars, recvDesc, sym, "", 0,
                (c, num, desc) -> action.apply(c, num.bigDecimal(), desc));
    }

    /**
     * Adds a binary operator method to vars.
     *
     * <p>Raises if the single arg is not a num;
     * otherwise the action receives the BigDecimals of the receiver and the arg.</p>
     */
    private void addBinaryOp(
            Map<Integer, Val> vars,
            String recvDesc,
            String sym,
            String argDesc,
            ThrowingFunction4<CallContext, BigDecimal, BigDecimal, String, HostResult> action) {
        addMethod(vars, recvDesc, sym, "(" + argDesc + ")", 1, (c, recv, desc) -> {
            Val arg = c.arg(0);
            if (! (arg instanceof NumVal argNum)) {
                return raiseNotNum(c, desc, argDesc, arg);
            }
            return action.apply(c, recv.bigDecimal(), argNum.bigDecimal(), desc);
        });
    }

    /**
     * Adds a unary int operator method to vars.
     *
     * <p>Raises if the receiver is not an int num (a num with scale 0).</p>
     */
    private void addUnaryIntOp(
            Map<Integer, Val> vars,
            String recvDesc,
            String sym,
            ThrowingFunction2<CallContext, BigInteger, HostResult> op) {
        addMethod(vars, recvDesc, sym, "", 0, (c, recv, desc) -> {
            BigInteger recvInt = exactBigInt(recv);
            if (recvInt == null) {
                return raiseNotIntNum(c, desc, recvDesc, recv);
            }
            return op.apply(c, recvInt);
        });
    }

    /**
     * Adds a binary int operator method to vars.
     *
     * <p>Raises if the receiver or the single arg is not an int num (a num with scale 0).
     * The receiver is checked first.</p>
     */
    private void addBinaryIntOp(
            Map<Integer, Val> vars,
            String recvDesc,
            String sym,
            String argDesc,
            ThrowingFunction3<CallContext, BigInteger, BigInteger, HostResult> op) {
        addMethod(vars, recvDesc, sym, "(" + argDesc + ")", 1, (c, recv, desc) -> {
            BigInteger recvInt = exactBigInt(recv);
            if (recvInt == null) {
                return raiseNotIntNum(c, desc, recvDesc, recv);
            }

            Val a0 = c.arg(0);
            BigInteger argInt = exactBigInt(a0);
            if (argInt == null) {
                return raiseNotIntNum(c, desc, argDesc, a0);
            }

            return op.apply(c, recvInt, argInt);
        });
    }

    /**
     * Adds a fixed arity method to vars.
     */
    private void addMethod(
            Map<Integer, Val> vars,
            String recvDesc,
            String sym,
            String argsDesc,
            int arity,
            ThrowingFunction3<CallContext, NumVal, String, HostResult> action) {
        addMethod(vars, recvDesc, sym, argsDesc, f -> f.take(arity), action);
    }

    /**
     * Adds a method to vars.
     *
     * <p>The method is described as {@code recvDesc.sym argsDesc}.
     * It raises if the receiver is not a num; otherwise it calls the action
     * with the receiver and the description.</p>
     */
    private void addMethod(
            Map<Integer, Val> vars,
            String recvDesc,
            String sym,
            String argsDesc,
            UnaryOperator<HostFunBuilder> configFun,
            ThrowingFunction3<CallContext, NumVal, String, HostResult> action) {
        String desc = String.format(Locale.ROOT, "%s.%s%s", recvDesc, sym, argsDesc);
        var fun = configFun.apply(vm.fun.make(desc)).action(c -> {
            Val recv = c.recv();
            if (! (recv instanceof NumVal recvNum)) {
                return raiseNotNum(c, desc, recvDesc, recv);
            }
            return action.apply(c, recvNum, desc);
        });
        vars.put(vm.sym.handleFor(sym), fun);
    }

    /**
     * Adds a shift method to vars.
     *
     * <p>Raises if the receiver is not an int num,
     * or if the bit count is not an int num representable as a Java int.</p>
     */
    private void addShiftMethod(
            Map<Integer, Val> vars,
            String recvDesc,
            String sym,
            String argDesc,
            ThrowingFunction2<BigInteger, Integer, BigInteger> action) {
        addMethod(vars, recvDesc, sym, "(" + argDesc + ")", 1, (c, recv, desc) -> {
            BigInteger recvInt = exactBigInt(recv);
            if (recvInt == null) {
                return raiseNotIntNum(c, desc, recvDesc, recv);
            }

            Val shift = c.arg(0);
            OptionalInt shiftInt = NumOperations.getExactInt(shift);
            if (! shiftInt.isPresent()) {
                return c.call(vm.graph.raiseFormat(
                            "{}: {} must be an int num between [{}, {}], but was {}",
                            vm.graph.of(vm.str.of(desc)),
                            vm.graph.of(vm.str.of(argDesc)),
                            vm.graph.of(vm.num.of(Integer.MIN_VALUE)),
                            vm.graph.of(vm.num.of(Integer.MAX_VALUE)),
                            vm.graph.repr(shift)));
            }
            return vm.num.of(action.apply(recvInt, shiftInt.getAsInt()));
        });
    }

    /**
     * Extracts an exact BigInteger from val.
     *
     * @param val the val.
     * @return the BigInteger if val is a num with scale 0; otherwise null.
     */
    private BigInteger exactBigInt(Val val) {
        if (! (val instanceof NumVal num)) {
            return null;
        }

        BigDecimal dec = num.bigDecimal();
        if (dec.scale() != 0) {
            return null;
        }

        return dec.toBigInteger();
    }

    /**
     * Makes an ifun, which takes $proc and $fin.
     *
     * <p>The ifun raises if $proc or $fin is not a fun; otherwise it calls the action.</p>
     *
     * @param prefix the description of the ifun, used both as the fun name and in error messages.
     * @param action the body of the ifun.
     * @return the ifun.
     */
    private FunVal makeIfun(
            String prefix,
            ThrowingFunction3<CallContext, FunVal, FunVal, HostResult> action) {
        // Name the fun after the prefix, so that the ifuns of Num.up and Num.down
        // are not reported as Num.times-ifun.
        return vm.fun.make(prefix).take(2).action(c -> {
            Val proc = c.arg(0);
            if (! (proc instanceof FunVal procFun)) {
                return c.call(vm.graph.raiseFormat("{}: required fun as $proc, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(proc)));
            }

            Val fin = c.arg(1);
            if (! (fin instanceof FunVal finFun)) {
                return c.call(vm.graph.raiseFormat("{}: required fun as $fin, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(fin)));
            }
            return action.apply(c, procFun, finFun);
        });
    }

}

// vim: et sw=4 sts=4 fdm=marker
