package org.kink_lang.kink;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.OptionalInt;

import org.kink_lang.kink.internal.function.ThrowingFunction2;
import org.kink_lang.kink.internal.function.ThrowingFunction3;
import org.kink_lang.kink.hostfun.HostContext;
import org.kink_lang.kink.hostfun.CallContext;
import org.kink_lang.kink.hostfun.HostResult;
import org.kink_lang.kink.internal.num.NumOperations;

/**
 * The helper for num vals.
 */
public class NumHelper {

    /** The vm. */
    private final Vm vm;

    /** The sym handle of "new"; set by {@link #init()}. */
    private int newHandle;

    /** The sym handle of "_show_num"; set by {@link #init()}. */
    private int showNumHandle;

    /** Shared vars of num vals; set by {@link #init()}. */
    SharedVars sharedVars;

    /**
     * Constructs the helper.
     *
     * <p>Sym handles and shared vars are not populated until {@link #init()} is called.</p>
     */
    NumHelper(Vm vm) {
        this.vm = vm;
    }

    /**
     * Returns a num val representing the specified number.
     *
     * @param bigDecimal a BigDecimal number.
     * @return a num val representing the specified number.
     * @throws IllegalArgumentException when the scale is negative.
     */
    public NumVal of(BigDecimal bigDecimal) {
        return new NumVal(vm, bigDecimal);
    }

    /**
     * Returns a num val representing the specified number.
     *
     * @param intNum an int number.
     * @return a num val representing the specified number.
     */
    public NumVal of(int intNum) {
        return new NumVal(vm, BigDecimal.valueOf(intNum));
    }

    /**
     * Returns a num val representing the specified number.
     *
     * @param longNum a long number.
     * @return a num val representing the specified number.
     */
    public NumVal of(long longNum) {
        return new NumVal(vm, BigDecimal.valueOf(longNum));
    }

    /**
     * Returns a num val representing the specified number.
     *
     * @param bigInt a BigInteger number.
     * @return a num val representing the specified number.
     */
    public NumVal of(BigInteger bigInt) {
        return new NumVal(vm, new BigDecimal(bigInt));
    }

    /**
     * Initializes the helper.
     *
     * <p>Resolves the sym handles used by the methods, and builds the shared vars
     * (the methods) of num vals.</p>
     */
    void init() {
        this.newHandle = vm.sym.handleFor("new");
        this.showNumHandle = vm.sym.handleFor("_show_num");

        Map<Integer, Val> vars = new HashMap<>();
        vars.put(vm.sym.handleFor("mantissa"), unaryOp("Num.mantissa",
                    (c, n) -> vm.num.of(n.unscaledValue())));
        vars.put(vm.sym.handleFor("scale"), unaryOp("Num.scale",
                    (c, n) -> vm.num.of(n.scale())));
        vars.put(vm.sym.handleFor("int?"), unaryOp("Num.int?",
                    (c, n) -> vm.bool.of(n.scale() == 0)));
        vars.put(vm.sym.handleFor("op_add"), binaryOp("Num.op_add",
                    (c, x, y) -> vm.num.of(x.add(y))));
        vars.put(vm.sym.handleFor("op_sub"), binaryOp("Num.op_sub",
                    (c, x, y) -> vm.num.of(x.subtract(y))));
        vars.put(vm.sym.handleFor("op_mul"), binaryOp("Num.op_mul",
                    (c, x, y) -> vm.num.of(x.multiply(y))));
        vars.put(vm.sym.handleFor("op_div"), binaryOp("Num.op_div", this::opDivMethod));
        vars.put(vm.sym.handleFor("op_intdiv"),
                binaryOp("Num.op_intdiv", this::opIntDivMethod));
        vars.put(vm.sym.handleFor("op_rem"), binaryOp("Num.op_rem", this::opRemMethod));
        vars.put(vm.sym.handleFor("op_minus"), unaryOp("Num.op_minus",
                    (c, n) -> vm.num.of(n.negate())));
        vars.put(vm.sym.handleFor("op_or"), binaryIntOp("Num.op_or",
                    (c, x, y) -> vm.num.of(x.or(y))));
        vars.put(vm.sym.handleFor("op_xor"), binaryIntOp("Num.op_xor",
                    (c, x, y) -> vm.num.of(x.xor(y))));
        vars.put(vm.sym.handleFor("op_and"), binaryIntOp("Num.op_and",
                    (c, x, y) -> vm.num.of(x.and(y))));
        vars.put(vm.sym.handleFor("op_not"), unaryIntOp("Num.op_not",
                    (c, x) -> vm.num.of(x.not())));
        vars.put(vm.sym.handleFor("op_shl"), shiftMethod("Num.op_shl", BigInteger::shiftLeft));
        vars.put(vm.sym.handleFor("op_shr"), shiftMethod("Num.op_shr", BigInteger::shiftRight));
        vars.put(vm.sym.handleFor("round"), method0("Num.round", this::roundMethod));
        vars.put(vm.sym.handleFor("abs"), unaryOp("Num.abs",
                    (c, n) -> vm.num.of(n.abs())));
        // compareTo (not equals) so that nums differing only in scale, such as 1 and 1.0,
        // are treated as equal
        vars.put(vm.sym.handleFor("op_eq"), binaryOp("Num.op_eq",
                    (c, x, y) -> vm.bool.of(x.compareTo(y) == 0)));
        vars.put(vm.sym.handleFor("op_lt"), binaryOp("Num.op_lt",
                    (c, x, y) -> vm.bool.of(x.compareTo(y) < 0)));
        vars.put(vm.sym.handleFor("times"), unaryOp("Num.times", this::timesMethod));
        vars.put(vm.sym.handleFor("up"), unaryOp("Num.up", this::upMethod));
        vars.put(vm.sym.handleFor("down"), unaryOp("Num.down", this::downMethod));
        vars.put(vm.sym.handleFor("repr"), unaryOp("Num.repr", this::reprMethod));
        vars.put(vm.sym.handleFor("show"), vm.fun.make("Num.show")
                .takeMinMax(0, 1).action(this::showMethod));
        this.sharedVars = vm.sharedVars.of(vars);
    }

    // Numer.op_div(Denom) {{{1

    /**
     * Implementation of Numer.op_div(Denom).
     *
     * <p>Raises on zero division; otherwise delegates to kink/NUM_DIV.new(Numer, Denom).</p>
     */
    private HostResult opDivMethod(CallContext c, BigDecimal dividend, BigDecimal divisor) {
        if (divisor.signum() == 0) {
            return c.raise(String.format(Locale.ROOT,
                        "Num.op_div: zero division: %s is divided by 0", dividend.toPlainString()));
        }
        return c.call("kink/NUM_DIV", newHandle).args(c.recv(), c.arg(0));
    }

    // }}}1
    // Numer.op_intdiv(Denom) {{{1

    /**
     * Implementation of Numer.op_intdiv(Denom).
     *
     * <p>Returns the Euclidean quotient; see {@link #euclideanQuotient}.</p>
     */
    private HostResult opIntDivMethod(HostContext c, BigDecimal numer, BigDecimal denom) {
        if (denom.signum() == 0) {
            return c.raise(String.format(Locale.ROOT,
                        "Num.op_intdiv: zero division: %s is divided by 0",
                        numer.toPlainString()));
        }
        return vm.num.of(euclideanQuotient(numer, denom));
    }

    // }}}1
    // Numer.op_rem(Denom) {{{1

    /**
     * Implementation of Numer.op_rem(Denom).
     *
     * <p>Returns {@code numer - quot * denom}, where quot is the Euclidean quotient.
     * Thus the result r satisfies {@code 0 <= r < |denom|}.</p>
     */
    private HostResult opRemMethod(HostContext c, BigDecimal numer, BigDecimal denom) {
        if (denom.signum() == 0) {
            return c.raise(String.format(Locale.ROOT,
                        "Num.op_rem: zero division: %s is divided by 0",
                        numer.toPlainString()));
        }
        BigDecimal quot = euclideanQuotient(numer, denom);
        BigDecimal rem = numer.subtract(quot.multiply(denom));
        return vm.num.of(rem);
    }

    /**
     * Returns the quotient of Euclidean division, as an integral BigDecimal with scale 0.
     *
     * <p>The quotient is rounded toward negative infinity when denom is positive,
     * and toward positive infinity when denom is negative.
     * As a result, the remainder {@code numer - quot * denom} is always
     * in the range {@code [0, |denom|)}.</p>
     *
     * @param numer the numerator.
     * @param denom the denominator; must not be zero.
     * @return the Euclidean quotient.
     */
    private static BigDecimal euclideanQuotient(BigDecimal numer, BigDecimal denom) {
        RoundingMode mode = denom.signum() < 0 ? RoundingMode.CEILING : RoundingMode.FLOOR;
        return numer.divide(denom, 0, mode);
    }

    // }}}1
    // Num.round {{{1

    /**
     * Implementation of Num.round.
     *
     * <p>Returns the result of kink/NUM_DIV.new(\recv, 1).
     * The num param is unused; it exists only to match the signature required by method0.</p>
     */
    private HostResult roundMethod(CallContext c, NumVal num) {
        return c.call("kink/NUM_DIV", newHandle).args(c.recv(), vm.num.of(1));
    }

    // }}}1
    // Num.times {{{1

    /**
     * Implementation of Num.times.
     *
     * <p>Returns an iter (made by kink/iter/ITER.new) of 0, 1, ..., count - 1.
     * Raises if count is not a nonnegative num with scale 0.</p>
     */
    private HostResult timesMethod(HostContext c, BigDecimal count) {
        if (count.scale() != 0 || count.signum() < 0) {
            return c.call(vm.graph.raiseFormat(
                        "Num.times: required nonnegative int as Num, but got {}",
                        vm.graph.of(vm.num.of(count))));
        }
        return c.call("kink/iter/ITER", newHandle).args(timesIfun(BigDecimal.ZERO, count));
    }

    /**
     * Ifun of Num.times.
     *
     * <p>Calls $proc with ind and the ifun of the next index while ind is less than count;
     * otherwise calls $fin.</p>
     */
    private FunVal timesIfun(BigDecimal ind, BigDecimal count) {
        return makeIfun("Num.times-ifun", (c, proc, fin) -> ind.compareTo(count) < 0
                ? c.call(proc).args(vm.num.of(ind), timesIfun(ind.add(BigDecimal.ONE), count))
                : c.call(fin));
    }

    // }}}1
    // Num.up and .down {{{1

    /**
     * Implementation of Num.up.
     *
     * <p>Returns an unbounded iter of start, start + 1, start + 2, ...</p>
     */
    private HostResult upMethod(HostContext c, BigDecimal start) {
        return c.call("kink/iter/ITER", newHandle)
            .args(upDownIfun("Num.up-ifun", start, BigDecimal.ONE));
    }

    /**
     * Implementation of Num.down.
     *
     * <p>Returns an unbounded iter of start, start - 1, start - 2, ...</p>
     */
    private HostResult downMethod(HostContext c, BigDecimal start) {
        return c.call("kink/iter/ITER", newHandle)
            .args(upDownIfun("Num.down-ifun", start, BigDecimal.valueOf(-1)));
    }

    /**
     * Ifun of Num.up and Num.down.
     *
     * <p>Always calls $proc with num and the ifun of num + delta; $fin is never called.</p>
     */
    private FunVal upDownIfun(String prefix, BigDecimal num, BigDecimal delta) {
        return makeIfun(prefix, (c, proc, fin) ->
                c.call(proc).args(vm.num.of(num), upDownIfun(prefix, num.add(delta), delta)));
    }

    // }}}1
    // Num.repr {{{1

    /**
     * Implementation of Num.repr.
     *
     * <p>Nums with negative scale are shown as {@code (num mantissa=M scale=S)};
     * other nums are shown as the plain string, parenthesized when negative.</p>
     */
    private HostResult reprMethod(HostContext c, BigDecimal num) {
        return vm.str.of(num.scale() < 0
            ? reprWithScale(num)
            : reprBasedOnPlain(num));
    }

    /**
     * Makes repr representation based on the plain string.
     */
    private String reprBasedOnPlain(BigDecimal num) {
        return num.signum() < 0
            ? String.format(Locale.ROOT, "(%s)", num.toPlainString())
            : num.toPlainString();
    }

    /**
     * Makes repr representation with mantissa and scale.
     */
    private String reprWithScale(BigDecimal num) {
        BigInteger mantissa = num.unscaledValue();
        int scale = num.scale();
        return String.format(Locale.ROOT, "(num mantissa=%d scale=%d)", mantissa, scale);
    }

    // }}}1
    // Num.show(...[$config]) {{{1

    /**
     * Implementation of Num.show(...[$config]).
     *
     * <p>Delegates to kink/NUM._show_num(\recv, $config).
     * When $config is omitted, a fun which takes one arg and returns nada is passed.</p>
     */
    private HostResult showMethod(CallContext c) {
        if (! (c.recv() instanceof NumVal)) {
            return c.raise("Num.show: expected num as \\recv, but got non-num");
        }

        Val config = c.argCount() == 0
            ? vm.fun.make().take(1).action(cc -> vm.nada)
            : c.arg(0);
        NumVal num = (NumVal) c.recv();
        return c.call("kink/NUM", showNumHandle).args(num, config);
    }

    // }}}1

    /**
     * Returns a unary operator fun.
     *
     * <p>The fun takes no args, checks that \recv is a num, and applies op to its BigDecimal.</p>
     */
    private FunVal unaryOp(
            String prefix, ThrowingFunction2<CallContext, BigDecimal, HostResult> op) {
        return vm.fun.make(prefix).take(0).action(c -> {
            Val recv = c.recv();
            if (! (recv instanceof NumVal)) {
                return c.call(vm.graph.raiseFormat("{}: required num as \\recv, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(recv)));
            }
            return op.apply(c, ((NumVal) recv).bigDecimal());
        });
    }

    /**
     * Returns a binary operator fun.
     *
     * <p>The fun takes one arg, checks that both \recv and the arg are nums,
     * and applies op to their BigDecimals.</p>
     */
    private FunVal binaryOp(
            String prefix, ThrowingFunction3<CallContext, BigDecimal, BigDecimal, HostResult> op) {
        return vm.fun.make(prefix).take(1).action(c -> {
            Val recv = c.recv();
            if (! (recv instanceof NumVal)) {
                return c.call(vm.graph.raiseFormat("{}: required num as \\recv, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(recv)));
            }

            Val a0 = c.arg(0);
            if (! (a0 instanceof NumVal)) {
                return c.call(vm.graph.raiseFormat("{}: the arg must be a num, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(a0)));
            }
            return op.apply(c, ((NumVal) recv).bigDecimal(), ((NumVal) a0).bigDecimal());
        });
    }

    /**
     * Makes a unary int operator fun.
     *
     * <p>The fun takes no args, checks that \recv is an exact integer num,
     * and applies op to it as a BigInteger.</p>
     */
    private FunVal unaryIntOp(
            String prefix, ThrowingFunction2<CallContext, BigInteger, HostResult> op) {
        return vm.fun.make(prefix).take(0).action(c -> {
            Val recv = c.recv();
            BigInteger recvInt = exactBigInt(recv);
            if (recvInt == null) {
                return c.call(vm.graph.raiseFormat("{}: required int num as \\recv, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(recv)));
            }
            return op.apply(c, recvInt);
        });
    }

    /**
     * Returns a binary int operator fun.
     *
     * <p>The fun takes one arg, checks that both \recv and the arg are exact integer nums,
     * and applies op to them as BigIntegers.</p>
     */
    private FunVal binaryIntOp(
            String prefix, ThrowingFunction3<CallContext, BigInteger, BigInteger, HostResult> op) {
        return vm.fun.make(prefix).take(1).action(c -> {
            Val recv = c.recv();
            BigInteger recvInt = exactBigInt(recv);
            if (recvInt == null) {
                return c.call(vm.graph.raiseFormat("{}: required int num as \\recv, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(recv)));
            }

            Val a0 = c.arg(0);
            BigInteger argInt = exactBigInt(a0);
            if (argInt == null) {
                return c.call(vm.graph.raiseFormat("{}: the arg must be an int num, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(a0)));
            }

            return op.apply(c, recvInt, argInt);
        });
    }

    /**
     * Makes a nullary method fun.
     *
     * <p>The fun takes no args, checks that \recv is a num, and passes it to action.</p>
     */
    private FunVal method0(
            String prefix, ThrowingFunction2<CallContext, NumVal, HostResult> action) {
        return vm.fun.make(prefix).take(0).action(c -> {
            Val recv = c.recv();
            if (! (recv instanceof NumVal)) {
                return c.call(vm.graph.raiseFormat("{}: required num as \\recv, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(recv)));
            }
            return action.apply(c, (NumVal) recv);
        });
    }

    /**
     * Makes a fun for shift methods.
     *
     * <p>\recv must be an exact integer num. The bit count arg must be accepted by
     * NumOperations.getExactInt (per the error message, an int num in the Java int range).</p>
     */
    private FunVal shiftMethod(
            String prefix, ThrowingFunction2<BigInteger, Integer, BigInteger> action) {
        return vm.fun.make(prefix).take(1).action(c -> {
            Val recv = c.recv();
            BigInteger recvInt = exactBigInt(recv);
            if (recvInt == null) {
                return c.call(vm.graph.raiseFormat("{}: required int num as \\recv, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(recv)));
            }

            Val shift = c.arg(0);
            OptionalInt shiftInt = NumOperations.getExactInt(shift);
            if (! shiftInt.isPresent()) {
                return c.call(vm.graph.raiseFormat(
                            "{}: Bit_count must be an int num between [{}, {}], but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.of(vm.num.of(Integer.MIN_VALUE)),
                            vm.graph.of(vm.num.of(Integer.MAX_VALUE)),
                            vm.graph.repr(shift)));
            }
            return vm.num.of(action.apply(recvInt, shiftInt.getAsInt()));
        });
    }

    /**
     * Extracts an exact BigInteger from val, or null.
     *
     * <p>Returns null when val is not a num, or when the num has a nonzero fractional part.</p>
     */
    private BigInteger exactBigInt(Val val) {
        if (! (val instanceof NumVal)) {
            return null;
        }

        BigDecimal dec = ((NumVal) val).bigDecimal();
        try {
            return dec.toBigIntegerExact();
        } catch (ArithmeticException ex) {
            // toBigIntegerExact throws when there is a nonzero fractional part
            return null;
        }
    }

    /**
     * Makes an ifun named prefix.
     *
     * <p>The ifun takes two args, $proc and $fin, checks that both are funs,
     * and then delegates to action.</p>
     *
     * <p>TODO: this method is not specific to nums, so it may be better moved to FunHelper.</p>
     */
    private FunVal makeIfun(
            String prefix,
            ThrowingFunction3<CallContext, FunVal, FunVal, HostResult> action) {
        return vm.fun.make(prefix).take(2).action(c -> {
            Val proc = c.arg(0);
            if (! (proc instanceof FunVal)) {
                return c.call(vm.graph.raiseFormat("{}: required fun as $proc, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(proc)));
            }

            Val fin = c.arg(1);
            if (! (fin instanceof FunVal)) {
                return c.call(vm.graph.raiseFormat("{}: required fun as $fin, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(fin)));
            }
            return action.apply(c, (FunVal) proc, (FunVal) fin);
        });
    }

}

// vim: et sw=4 sts=4 fdm=marker
