package org.kink_lang.kink;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.UnaryOperator;
import java.util.stream.IntStream;

import org.kink_lang.kink.internal.function.ThrowingFunction3;
import org.kink_lang.kink.hostfun.CallContext;
import org.kink_lang.kink.hostfun.HostContext;
import org.kink_lang.kink.hostfun.HostFunBuilder;
import org.kink_lang.kink.hostfun.HostResult;
import org.kink_lang.kink.internal.num.NumOperations;

/**
 * The helper of str vals.
 *
 * <p>Creates str vals via {@link #of(String)}, and builds the shared vars
 * (the methods) of str vals in {@link #init()}.</p>
 */
public class StrHelper {

    /** The vm. */
    private final Vm vm;

    /** The sym handle of "format". */
    private int formatHandle;

    /** The sym handle of "arg". */
    private int argHandle;

    /** The max length of strs. Modifiable by tests. */
    int maxLength = Integer.MAX_VALUE;

    /** Shared vars of str vals; assigned by {@link #init()}. */
    SharedVars sharedVars;

    /**
     * Constructs the helper.
     *
     * @param vm the vm.
     */
    StrHelper(Vm vm) {
        this.vm = vm;
    }

    /**
     * Returns a str val containing the string.
     *
     * @param string the string.
     * @return a str val.
     */
    public StrVal of(String string) {
        return new StrVal(vm, string);
    }

    /**
     * Returns the max length of strs.
     *
     * @return the max length of strs.
     */
    int getMaxLength() {
        return this.maxLength;
    }

    /**
     * Initializes the helper.
     *
     * <p>Resolves the sym handles used by Str.format,
     * and builds {@link #sharedVars} from the methods of str vals.</p>
     */
    void init() {
        this.formatHandle = vm.sym.handleFor("format");
        this.argHandle = vm.sym.handleFor("arg");

        Map<Integer, Val> vars = new HashMap<>();
        addMethod(vars, "Str1", "op_add", "(Str2)", 1, this::opAddMethod);
        addMethod(vars, "Str", "runes", "", 0, this::runesMethod);
        addMethod(vars, "Str", "empty?", "", 0, (c, s, d) -> vm.bool.of(s.string().isEmpty()));
        addMethod(vars, "Str", "size", "", 0, (c, s, d) -> vm.num.of(s.runeCount()));
        addMethod(vars, "Str", "get", "(Rune_index)", 1, this::getMethod);
        addMethod(vars, "Str", "slice", "(From_pos To_pos)", 2, this::sliceMethod);
        addTakeDropMethod(vars, "take_front", 0, 0, 1, 0);
        addTakeDropMethod(vars, "take_back", -1, 1, 0, 1);
        addTakeDropMethod(vars, "drop_front", 1, 0, 0, 1);
        addTakeDropMethod(vars, "drop_back", 0, 0, -1, 1);
        addMethod(vars, "Str", "search_slice", "(Min_ind Slice)", 2, this::searchSliceMethod);
        addBinaryOp(vars, "Str", "have_prefix?", "Prefix",
                (c, text, pat) -> vm.bool.of(text.string().startsWith(pat.string())));
        addBinaryOp(vars, "Str", "have_suffix?", "Suffix",
                (c, text, pat) -> vm.bool.of(text.string().endsWith(pat.string())));
        addBinaryOp(vars, "Str", "have_slice?", "Slice",
                (c, text, pat) -> vm.bool.of(text.string().contains(pat.string())));
        addMethod(vars, "Str", "trim", "", 0, (c, s, d) -> vm.str.of(s.string().trim()));
        addMethod(vars, "Str", "ascii_upcase", "", 0, this::asciiUpcaseMethod);
        addMethod(vars, "Str", "ascii_downcase", "", 0, this::asciiDowncaseMethod);
        addMethod(vars, "Str", "format", "(,,,)", b -> b, this::formatMethod);
        addMethod(vars, "Str", "op_mul", "(Count)", 1, this::opMulMethod);
        addBinaryOp(vars, "Str1", "op_lt", "Str2",
                (c, x, y) -> vm.bool.of(lessThanByRunes(x.string(), y.string())));
        addBinaryOp(vars, "Str1", "op_eq", "Str2",
                (c, x, y) -> vm.bool.of(x.string().equals(y.string())));
        addMethod(vars, "Str", "repr", "", 0, this::reprMethod);
        addMethod(vars, "Str", "show", "(...[$config])", b -> b.takeMinMax(0, 1), (c, s, d) -> s);
        this.sharedVars = vm.sharedVars.of(vars);
    }

    // Str1.op_add(Str2) {{{1

    /**
     * Implementation of Str1.op_add(Str2).
     *
     * <p>Raises if Str2 is not a str, or if the result would exceed the max length.</p>
     */
    private HostResult opAddMethod(CallContext c, StrVal str, String desc) {
        if (! (c.arg(0) instanceof StrVal argStr)) {
            return c.call(vm.graph.raiseFormat(
                        "{}: required str as Str2, but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.repr(c.arg(0))));
        }

        // widen to long so that the sum of two large lengths cannot overflow
        if ((long) str.length() + argStr.length() > vm.str.getMaxLength()) {
            return c.raise(desc + ": too long result");
        }

        return str.concat(argStr);
    }

    // }}}1

    // Str.runes {{{1

    /**
     * Implementation of Str.runes: returns a vec of the code points as nums.
     */
    private HostResult runesMethod(CallContext c, StrVal strVal, String desc) {
        List<NumVal> runes = strVal.string().codePoints()
            .mapToObj(rune -> vm.num.of(rune))
            .toList();
        return vm.vec.of(runes);
    }

    // }}}1

    // Str.get(Rune_index) {{{1

    /**
     * Implementation of Str.get(Rune_index): returns the rune at the rune index as a num.
     */
    private HostResult getMethod(CallContext c, StrVal strVal, String desc) {
        String str = strVal.string();
        Val runeIndexVal = c.arg(0);
        int runeIndex = NumOperations.getElemIndex(runeIndexVal, strVal.runeCount());
        if (runeIndex < 0) {
            return c.call(vm.graph.raiseFormat(
                        "{}: Rune_index must be int num in [0, {}), but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.of(vm.num.of(strVal.runeCount())),
                        vm.graph.repr(runeIndexVal)));
        }
        int index = strVal.runePosToCharPos(runeIndex);
        int rune = str.codePointAt(index);
        return vm.num.of(rune);
    }

    // }}}1

    // Str.slice(From To) {{{1

    /**
     * Implementation of Str.slice(From To).
     *
     * <p>From and To are rune positions; they are converted to char positions
     * before taking the substring.</p>
     */
    private HostResult sliceMethod(CallContext c, StrVal strVal, String desc) {
        if (! (c.arg(0) instanceof NumVal fromNum)) {
            return c.call(vm.graph.raiseFormat("{}: From_pos must be a num, but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.repr(c.arg(0))));
        }
        BigDecimal fromDec = fromNum.bigDecimal();
        if (! (c.arg(1) instanceof NumVal toNum)) {
            return c.call(vm.graph.raiseFormat("{}: To_pos must be a num, but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.repr(c.arg(1))));
        }
        BigDecimal toDec = toNum.bigDecimal();
        String str = strVal.string();
        if (! NumOperations.isRangePair(fromDec, toDec, strVal.runeCount())) {
            return c.call(vm.graph.raiseFormat(
                        "{}: required index pair in [0, {}], but got {} and {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.of(vm.num.of(strVal.runeCount())),
                        vm.graph.repr(fromNum),
                        vm.graph.repr(toNum)));
        }
        int fromRuneInd = fromDec.intValueExact();
        int toRuneInd = toDec.intValueExact();
        int fromInd = strVal.runePosToCharPos(fromRuneInd);
        int toInd = strVal.runePosToCharPos(toRuneInd);
        return vm.str.of(str.substring(fromInd, toInd));
    }

    // }}}1

    // Str.take_front, take_back, drop_front, drop_back {{{1

    /**
     * Add take_xxx or drop_xxx method to vars.
     *
     * <p>The resulting slice spans rune positions
     * {@code [fromNumCoef * N + fromSizeCoef * size, toNumCoef * N + toSizeCoef * size)},
     * where N is the argument and size is the rune count of the receiver.</p>
     */
    private void addTakeDropMethod(
            Map<Integer, Val> vars,
            String sym,
            int fromNumCoef, int fromSizeCoef, int toNumCoef, int toSizeCoef) {
        addMethod(vars, "Str", sym, "(N)", 1, (c, strVal, desc) -> {
            String str = strVal.string();
            int num = NumOperations.getPosIndex(c.arg(0), strVal.runeCount());
            if (num < 0) {
                return c.call(vm.graph.raiseFormat(
                            "{}: N must be int num in [0, {}], but got {}",
                            vm.graph.of(vm.str.of(desc)),
                            vm.graph.of(vm.num.of(strVal.runeCount())),
                            vm.graph.repr(c.arg(0))));
            }
            int fromRuneInd = fromNumCoef * num + fromSizeCoef * strVal.runeCount();
            int toRuneInd = toNumCoef * num + toSizeCoef * strVal.runeCount();
            int fromCharInd = strVal.runePosToCharPos(fromRuneInd);
            int toCharInd = strVal.runePosToCharPos(toRuneInd);
            return vm.str.of(str.substring(fromCharInd, toCharInd));
        });
    }

    // }}}1

    // Str.search_slice(Min_ind Slice) {{{1

    /**
     * Implementation of Str.search_slice(Min_ind Slice).
     *
     * <p>Returns a single-element vec of the rune index of the first occurrence
     * at or after Min_ind, or an empty vec if not found.</p>
     */
    private HostResult searchSliceMethod(CallContext c, StrVal strVal, String desc) {
        String str = strVal.string();
        int minRuneInd = NumOperations.getPosIndex(c.arg(0), strVal.runeCount());
        if (minRuneInd < 0) {
            return c.call(vm.graph.raiseFormat(
                        "{}: Min_ind must be int num in [0, {}], but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.of(vm.num.of(strVal.runeCount())),
                        vm.graph.repr(c.arg(0))));
        }

        if (! (c.arg(1) instanceof StrVal slice)) {
            return c.call(vm.graph.raiseFormat(
                        "{}: Slice must be str, but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.repr(c.arg(1))));
        }

        int minCharInd = strVal.runePosToCharPos(minRuneInd);
        int charInd = str.indexOf(slice.string(), minCharInd);
        if (charInd < 0) {
            return vm.vec.of();
        }
        int runeInd = str.codePointCount(0, charInd);
        return vm.vec.of(vm.num.of(runeInd));
    }

    // }}}1

    // Str.ascii_upcase, Str.ascii_downcase {{{1

    /**
     * Implementation of Str.ascii_upcase: maps only 'a'..'z' to 'A'..'Z'.
     */
    private HostResult asciiUpcaseMethod(CallContext c, StrVal strVal, String desc) {
        return vm.str.of(shiftCharRange(strVal.string(), 'a', 'z', 'A' - 'a'));
    }

    /**
     * Implementation of Str.ascii_downcase: maps only 'A'..'Z' to 'a'..'z'.
     */
    private HostResult asciiDowncaseMethod(CallContext c, StrVal strVal, String desc) {
        return vm.str.of(shiftCharRange(strVal.string(), 'A', 'Z', 'a' - 'A'));
    }

    /**
     * Returns a copy of src where each char in [lo, hi] is shifted by delta.
     *
     * <p>Other chars, including non-ASCII ones, are copied unchanged.</p>
     */
    private static String shiftCharRange(String src, char lo, char hi, int delta) {
        StringBuilder sb = new StringBuilder(src.length());
        for (int i = 0; i < src.length(); ++ i) {
            char ch = src.charAt(i);
            sb.append(lo <= ch && ch <= hi ? (char) (ch + delta) : ch);
        }
        return sb.toString();
    }

    // }}}1

    // Str.format(...Args) {{{1

    /**
     * Implementation of Str.format(...Args).
     *
     * <p>If the sole argument is a fun, it is used as the config fun directly.
     * Otherwise a config fun is synthesized which passes each argument,
     * in order, to {@code arg} of the config val.</p>
     */
    private HostResult formatMethod(CallContext c, StrVal template, String desc) {
        FunVal config;
        if (c.argCount() == 1 && c.arg(0) instanceof FunVal givenConfig) {
            config = givenConfig;
        } else {
            List<Val> args = IntStream.range(0, c.argCount())
                .mapToObj(i -> c.arg(i))
                .toList();
            config = vm.fun.make().take(1).action(
                    cc -> formatMethodConfigAux(cc, cc.arg(0), args));
        }
        return c.call("kink/_str/FORMAT", formatHandle).args(template, config);
    }

    /**
     * Implementation of config fun for Str.format.
     *
     * <p>Calls {@code conf.arg(A)} for the head of args, then continues with the rest.</p>
     */
    private HostResult formatMethodConfigAux(HostContext c, Val conf, List<Val> args) {
        if (args.isEmpty()) {
            return vm.nada;
        }

        return c.call(conf, this.argHandle).args(args.get(0))
            .on((cc, r) -> formatMethodConfigAux(cc, conf, args.subList(1, args.size())));
    }

    // }}}1

    // Str.op_mul(Count) {{{1

    /**
     * Implementation of Str.op_mul(Count): repeats the str Count times.
     *
     * <p>Raises if Count is not a non-negative int num,
     * or if the result would exceed the max length.</p>
     */
    private HostResult opMulMethod(CallContext c, StrVal strVal, String desc) {
        String str = strVal.string();
        BigInteger count = NumOperations.getExactBigInteger(c.arg(0));
        if (count == null || count.signum() < 0) {
            return c.call(vm.graph.raiseFormat(
                        "{}: required int num >=0 for Count, but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.repr(c.arg(0))));
        }
        if (str.isEmpty()) {
            return vm.str.of("");
        }
        // compute in BigInteger so that huge counts cannot overflow before the check
        BigInteger resultLen = count.multiply(BigInteger.valueOf(str.length()));
        if (resultLen.compareTo(BigInteger.valueOf(vm.str.getMaxLength())) > 0) {
            return c.raise(desc + ": too long result");
        }
        // count <= maxLength / str.length() <= Integer.MAX_VALUE here, so it fits in an int
        return vm.str.of(str.repeat(count.intValueExact()));
    }

    // }}}1

    // Str.op_lt {{{1

    /**
     * Compares strings lexicographically by runes (code points), not by UTF-16 chars.
     *
     * @return true if x is less than y.
     */
    private static boolean lessThanByRunes(String x, String y) {
        int xi = 0;
        int yi = 0;
        while (true) {
            if (yi >= y.length()) {
                // y is exhausted: x is equal to or longer than y
                return false;
            }
            if (xi >= x.length()) {
                // x is a proper prefix of y
                return true;
            }
            int xrune = x.codePointAt(xi);
            int yrune = y.codePointAt(yi);
            if (xrune < yrune) {
                return true;
            } else if (xrune > yrune) {
                return false;
            }
            xi = x.offsetByCodePoints(xi, 1);
            yi = y.offsetByCodePoints(yi, 1);
        }
    }

    // }}}1

    // Str.repr {{{1

    /** Mapping from special characters to their representations in rich string literals. */
    private static final Map<Character, String> CHAR_TO_REPR;

    static {
        Map<Character, String> map = new HashMap<>();
        map.put('\u0000', "\\0");
        map.put('\u0007', "\\a");
        map.put('\b', "\\b");
        map.put('\t', "\\t");
        map.put('\n', "\\n");
        map.put('\u000b', "\\v");
        map.put('\f', "\\f");
        map.put('\r', "\\r");
        map.put('\u001b', "\\e");
        map.put('"', "\\\"");
        map.put('\\', "\\\\");
        map.put('\u007f', "\\x{00007f}");
        // the remaining C0 control chars without a dedicated escape use \x{...}
        for (char ch = '\u0001'; ch <= '\u001f'; ++ ch) {
            map.putIfAbsent(ch, String.format(Locale.ROOT, "\\x{%06x}", (int) ch));
        }
        CHAR_TO_REPR = Collections.unmodifiableMap(map);
    }

    /**
     * Implementation of Str.repr.
     */
    private StrVal reprMethod(CallContext c, StrVal strVal, String desc) {
        return reprMethodImpl(strVal.string());
    }

    /**
     * Returns the repr of the string as a double-quoted rich string literal.
     *
     * <p>Open to test.</p>
     *
     * @param str the string.
     * @return the repr str val.
     */
    StrVal reprMethodImpl(String str) {
        StringBuilder sb = new StringBuilder(str.length() + 2);
        sb.append('"');
        for (int i = 0; i < str.length(); ++ i) {
            char ch = str.charAt(i);
            String repr = CHAR_TO_REPR.get(ch);
            if (repr != null) {
                sb.append(repr);
            } else {
                // surrogate halves are appended one at a time, which keeps pairs intact
                sb.append(ch);
            }
        }
        sb.append('"');
        return vm.str.of(sb.toString());
    }

    // }}}1

    /**
     * Add a method to vars, with fixed arity.
     */
    private void addMethod(
            Map<Integer, Val> vars,
            String recvDesc,
            String sym,
            String argsDesc,
            int arity,
            ThrowingFunction3<CallContext, StrVal, String, HostResult> action) {
        addMethod(vars, recvDesc, sym, argsDesc, b -> b.take(arity), action);
    }

    /**
     * Add a method to vars.
     *
     * <p>The method raises if the receiver is not a str;
     * otherwise it calls action with the description "recvDesc.sym argsDesc".</p>
     */
    private void addMethod(
            Map<Integer, Val> vars,
            String recvDesc,
            String sym,
            String argsDesc,
            UnaryOperator<HostFunBuilder> limitArity,
            ThrowingFunction3<CallContext, StrVal, String, HostResult> action) {
        String desc = String.format(Locale.ROOT, "%s.%s%s", recvDesc, sym, argsDesc);
        int handle = vm.sym.handleFor(sym);
        var fun = limitArity.apply(vm.fun.make(desc)).action(c -> {
            if (! (c.recv() instanceof StrVal recvStr)) {
                return c.call(vm.graph.raiseFormat("{}: {} must be str, but got {}",
                            vm.graph.of(vm.str.of(desc)),
                            vm.graph.of(vm.str.of(recvDesc)),
                            vm.graph.repr(c.recv())));
            }
            return action.apply(c, recvStr, desc);
        });
        vars.put(handle, fun);
    }

    /**
     * Add a binary op method to vars.
     *
     * <p>The method raises if the single argument is not a str.</p>
     */
    private void addBinaryOp(
            Map<Integer, Val> vars,
            String recvDesc,
            String sym,
            String argDesc,
            ThrowingFunction3<CallContext, StrVal, StrVal, HostResult> action) {
        addMethod(vars, recvDesc, sym, "(" + argDesc + ")", 1, (c, recv, desc) -> {
            if (! (c.arg(0) instanceof StrVal argStr)) {
                return c.call(vm.graph.raiseFormat("{}: {} must be str, but got {}",
                            vm.graph.of(vm.str.of(desc)),
                            vm.graph.of(vm.str.of(argDesc)),
                            vm.graph.repr(c.arg(0))));
            }
            return action.apply(c, recv, argStr);
        });
    }

}

// vim: et sw=4 sts=4 fdm=marker
