package org.kink_lang.kink;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.kink_lang.kink.internal.function.ThrowingFunction3;
import org.kink_lang.kink.internal.function.ThrowingFunction4;
import org.kink_lang.kink.hostfun.HostContext;
import org.kink_lang.kink.hostfun.CallContext;
import org.kink_lang.kink.hostfun.HostResult;
import org.kink_lang.kink.internal.num.NumOperations;

/**
 * The helper of str vals.
 *
 * <p>Provides the factory {@link #of(String)} and, after {@link #init()},
 * the shared vars (methods) of str vals such as Str.op_add, Str.slice and Str.format.</p>
 */
public class StrHelper {

    /** The vm. */
    private final Vm vm;

    /** The sym handle of "format"; used to call kink/_str/FORMAT.format. */
    private int formatHandle;

    /** The sym handle of "arg"; used to feed each arg to the config val of Str.format. */
    private int argHandle;

    /** The max length of strs. Modifiable by tests. */
    int maxLength = Integer.MAX_VALUE;

    /** Shared vars of str vals; set by {@link #init()}. */
    SharedVars sharedVars;

    /**
     * Constructs the helper.
     *
     * <p>{@link #init()} must be called before the shared vars are used.</p>
     *
     * @param vm the vm.
     */
    StrHelper(Vm vm) {
        this.vm = vm;
    }

    /**
     * Returns a str val containing the string.
     *
     * @param string the string.
     * @return a str val.
     */
    public StrVal of(String string) {
        return new StrVal(vm, string);
    }

    /**
     * Returns the max length of strs.
     *
     * @return the max length of strs.
     */
    int getMaxLength() {
        return this.maxLength;
    }

    /**
     * Initializes the helper: resolves the sym handles and builds the shared vars of str vals.
     */
    void init() {
        this.formatHandle = vm.sym.handleFor("format");
        this.argHandle = vm.sym.handleFor("arg");

        Map<Integer, Val> vars = new HashMap<>();
        vars.put(vm.sym.handleFor("op_add"), method1("Str.op_add(Arg_str)", this::opAddMethod));
        vars.put(vm.sym.handleFor("runes"), unaryOp("Str.runes", this::runesMethod));
        vars.put(vm.sym.handleFor("empty?"),
                unaryOp("Str.empty?", s -> vm.bool.of(s.isEmpty())));
        vars.put(vm.sym.handleFor("size"), unaryOp("Str.size", this::sizeMethod));
        vars.put(vm.sym.handleFor("get"), method1("Str.get(Rune_index)", this::getMethod));
        vars.put(vm.sym.handleFor("slice"),
                method2("Str.slice(From_pos To_pos)", this::sliceMethod));
        // Coefficients: from = fromNumCoef * N + fromSizeCoef * size,
        //               to   = toNumCoef   * N + toSizeCoef   * size.
        vars.put(vm.sym.handleFor("take_front"), takeDropMethod("Str.take_front(N)", 0, 0, 1, 0));
        vars.put(vm.sym.handleFor("take_back"), takeDropMethod("Str.take_back(N)", -1, 1, 0, 1));
        vars.put(vm.sym.handleFor("drop_front"), takeDropMethod("Str.drop_front(N)", 1, 0, 0, 1));
        vars.put(vm.sym.handleFor("drop_back"), takeDropMethod("Str.drop_back(N)", 0, 0, -1, 1));
        vars.put(vm.sym.handleFor("search_slice"),
                method2("Str.search_slice(Min_ind Slice)", this::searchSliceMethod));
        vars.put(vm.sym.handleFor("have_prefix?"),
                binaryOp("Str.have_prefix?", (text, pat) -> vm.bool.of(text.startsWith(pat))));
        vars.put(vm.sym.handleFor("have_suffix?"),
                binaryOp("Str.have_suffix?", (text, pat) -> vm.bool.of(text.endsWith(pat))));
        vars.put(vm.sym.handleFor("have_slice?"),
                binaryOp("Str.have_slice?(Slice)", (text, pat) -> vm.bool.of(text.contains(pat))));
        vars.put(vm.sym.handleFor("trim"), unaryOp("Str.trim", s -> vm.str.of(s.trim())));
        vars.put(vm.sym.handleFor("ascii_upcase"),
                unaryOp("Str.ascii_upcase", this::asciiUpcaseMethod));
        vars.put(vm.sym.handleFor("ascii_downcase"),
                unaryOp("Str.ascii_downcase", this::asciiDowncaseMethod));
        vars.put(vm.sym.handleFor("format"),
                vm.fun.make("Str.format").action(this::formatMethod));
        vars.put(vm.sym.handleFor("op_mul"),
                method1("Str.op_mul(Count)", this::opMulMethod));
        vars.put(vm.sym.handleFor("op_lt"), binaryOp("Str.op_lt",
                    (x, y) -> vm.bool.of(lessThanByRunes(x, y))));
        vars.put(vm.sym.handleFor("op_eq"), binaryOp("Str.op_eq",
                    (x, y) -> vm.bool.of(x.equals(y))));
        vars.put(vm.sym.handleFor("repr"), unaryOp("Str.repr", this::reprMethod));
        vars.put(vm.sym.handleFor("show"), vm.fun.make("Str.show")
                .takeMinMax(0, 1).action(this::showMethod));
        this.sharedVars = vm.sharedVars.of(vars);
    }

    // Str.op_add(Arg_str) {{{1

    /**
     * Implementation of Str.op_add(Arg_str).
     *
     * <p>Raises if Arg_str is not a str, or if the combined length exceeds the max length.</p>
     */
    private HostResult opAddMethod(HostContext c, StrVal str, Val argStrVal) {
        if (! (argStrVal instanceof StrVal)) {
            return c.call(vm.graph.raiseFormat(
                        "Str.op_add(Arg_str): required str as Arg_str, but got {}",
                        vm.graph.repr(argStrVal)));
        }
        StrVal argStr = (StrVal) argStrVal;

        // Widen to long so that the sum of two large lengths cannot overflow int.
        if ((long) str.length() + argStr.length() > vm.str.getMaxLength()) {
            return c.raise("Str.op_add(Arg_str): too long result");
        }

        return str.concat(argStr);
    }

    // }}}1
    // Str.runes {{{1

    /**
     * Implementation of Str.runes: returns a vec of the code points of the str.
     */
    private HostResult runesMethod(String str) {
        List<NumVal> runes = str.codePoints()
            .mapToObj(rune -> vm.num.of(rune))
            .collect(Collectors.toList());
        return vm.vec.of(runes);
    }

    // }}}1
    // Str.size {{{1

    /**
     * Implementation of Str.size: returns the number of code points (not UTF-16 chars).
     */
    private HostResult sizeMethod(String str) {
        return vm.num.of(str.codePointCount(0, str.length()));
    }

    // }}}1
    // Str.get(Rune_index) {{{1

    /**
     * Implementation of Str.get(Rune_index): returns the code point at the rune index.
     */
    private HostResult getMethod(HostContext c, StrVal strVal, Val runeIndexVal) {
        String str = strVal.string();
        int runeCount = str.codePointCount(0, str.length());
        int runeIndex = NumOperations.getElemIndex(runeIndexVal, runeCount);
        if (runeIndex < 0) {
            return c.call(vm.graph.raiseFormat(
                        "Str.get(Rune_index): Rune_index must be an int num in [0, {}), but got {}",
                        vm.graph.of(vm.num.of(runeCount)),
                        vm.graph.repr(runeIndexVal)));
        }
        // Convert the rune index to a UTF-16 char index.
        int index = str.offsetByCodePoints(0, runeIndex);
        int rune = str.codePointAt(index);
        return vm.num.of(rune);
    }

    // }}}1
    // Str.slice(From To) {{{1

    /**
     * Implementation of Str.slice(From To): returns the runes in [From_pos, To_pos).
     */
    private HostResult sliceMethod(HostContext c, String str, Val fromVal, Val toVal) {
        String desc = "Str.slice(From_pos To_pos)";
        if (! (fromVal instanceof NumVal)) {
            return c.call(vm.graph.raiseFormat("{}: From_pos must be a num, but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.repr(fromVal)));
        }
        BigDecimal fromDec = ((NumVal) fromVal).bigDecimal();
        if (! (toVal instanceof NumVal)) {
            return c.call(vm.graph.raiseFormat("{}: To_pos must be a num, but got {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.repr(toVal)));
        }
        BigDecimal toDec = ((NumVal) toVal).bigDecimal();
        int runeCount = str.codePointCount(0, str.length());
        if (! NumOperations.isRangePair(fromDec, toDec, runeCount)) {
            return c.call(vm.graph.raiseFormat(
                        "{}: required index pair in [0, {}], but got {} and {}",
                        vm.graph.of(vm.str.of(desc)),
                        vm.graph.of(vm.num.of(runeCount)),
                        vm.graph.repr(fromVal),
                        vm.graph.repr(toVal)));
        }
        int fromRuneInd = fromDec.intValueExact();
        int toRuneInd = toDec.intValueExact();
        // Rune indices are converted to UTF-16 char indices for substring.
        int fromInd = str.offsetByCodePoints(0, fromRuneInd);
        int toInd = str.offsetByCodePoints(0, toRuneInd);
        return vm.str.of(str.substring(fromInd, toInd));
    }

    // }}}1
    // Str.take_front, take_back, drop_front, drop_back {{{1

    /**
     * Makes one of Str.take_front, take_back, drop_front, drop_back.
     *
     * <p>The result is the runes in [from, to), where
     * {@code from = fromNumCoef * N + fromSizeCoef * size} and
     * {@code to = toNumCoef * N + toSizeCoef * size}.</p>
     *
     * @param prefix the description of the method used in error messages.
     * @param fromNumCoef the coefficient of N for the start rune index.
     * @param fromSizeCoef the coefficient of the rune count for the start rune index.
     * @param toNumCoef the coefficient of N for the end rune index.
     * @param toSizeCoef the coefficient of the rune count for the end rune index.
     * @return the method fun.
     */
    private FunVal takeDropMethod(String prefix,
            int fromNumCoef, int fromSizeCoef, int toNumCoef, int toSizeCoef) {
        return method1(prefix, (c, strVal, numVal) -> {
            String str = strVal.string();
            int runeCount = str.codePointCount(0, str.length());
            int num = NumOperations.getPosIndex(numVal, runeCount);
            if (num < 0) {
                return c.call(vm.graph.raiseFormat(
                            "{}: N must be an int num in [0, {}], but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.of(vm.num.of(runeCount)),
                            vm.graph.repr(numVal)));
            }
            int fromRuneInd = fromNumCoef * num + fromSizeCoef * runeCount;
            int toRuneInd = toNumCoef * num + toSizeCoef * runeCount;
            int fromCharInd = str.offsetByCodePoints(0, fromRuneInd);
            int toCharInd = str.offsetByCodePoints(0, toRuneInd);
            return vm.str.of(str.substring(fromCharInd, toCharInd));
        });
    }

    // }}}1
    // Str.search_slice(Min_ind Slice) {{{1

    /**
     * Implementation of Str.search_slice.
     *
     * <p>Returns a vec of the rune index of the first occurrence of Slice at or after Min_ind,
     * or an empty vec if not found.</p>
     */
    private HostResult searchSliceMethod(
            CallContext c, String str, Val minIndVal, Val sliceStrVal) {
        String prefix = "Str.search_slice(Min_ind Slice)";

        if (! (sliceStrVal instanceof StrVal)) {
            return c.call(vm.graph.raiseFormat(
                        "{}: required str for Slice, but got {}",
                        vm.graph.of(vm.str.of(prefix)),
                        vm.graph.repr(sliceStrVal)));
        }
        String slice = ((StrVal) sliceStrVal).string();

        int runeCount = str.codePointCount(0, str.length());
        int minRuneInd = NumOperations.getPosIndex(minIndVal, runeCount);
        if (minRuneInd < 0) {
            return c.call(vm.graph.raiseFormat(
                        "{}: required int num in [0, {}] for Min_ind, but got {}",
                        vm.graph.of(vm.str.of(prefix)),
                        vm.graph.of(vm.num.of(runeCount)),
                        vm.graph.repr(minIndVal)));
        }
        // Search in UTF-16 char indices, then convert the hit back to a rune index.
        int minCharInd = str.offsetByCodePoints(0, minRuneInd);
        int charInd = str.indexOf(slice, minCharInd);
        if (charInd < 0) {
            return vm.vec.of();
        }
        int runeInd = str.codePointCount(0, charInd);
        return vm.vec.of(vm.num.of(runeInd));
    }

    // }}}1
    // Str.ascii_upcase {{{

    /**
     * Implementation of Str.ascii_upcase: converts only 'a'..'z' to 'A'..'Z'.
     */
    private HostResult asciiUpcaseMethod(String src) {
        StringBuilder sb = new StringBuilder(src.length());
        for (int i = 0; i < src.length(); ++ i) {
            char ch = src.charAt(i);
            char dest = 'a' <= ch && ch <= 'z'
                ? (char) (ch + ('A' - 'a'))
                : ch;
            sb.append(dest);
        }
        return vm.str.of(sb.toString());
    }

    // }}}
    // Str.ascii_downcase {{{

    /**
     * Implementation of Str.ascii_downcase: converts only 'A'..'Z' to 'a'..'z'.
     */
    private HostResult asciiDowncaseMethod(String src) {
        StringBuilder sb = new StringBuilder(src.length());
        for (int i = 0; i < src.length(); ++ i) {
            char ch = src.charAt(i);
            char dest = 'A' <= ch && ch <= 'Z'
                ? (char) (ch + ('a' - 'A'))
                : ch;
            sb.append(dest);
        }
        return vm.str.of(sb.toString());
    }

    // }}}
    // Str.format(...Args) {{{1

    /**
     * Implementation of Str.format(...Args).
     *
     * <p>If the only arg is a fun, it is passed as the config to kink/_str/FORMAT.format.
     * Otherwise a config fun is synthesized which feeds each arg in order via
     * {@code $conf.arg}.</p>
     */
    private HostResult formatMethod(CallContext c) {
        Val template = c.recv();
        if (! (template instanceof StrVal)) {
            return c.call(vm.graph.raiseFormat(
                        "Str.format(,,,): required a str for Str, but got {}",
                        vm.graph.repr(template)));
        }

        if (c.argCount() == 1 && c.arg(0) instanceof FunVal) {
            FunVal config = (FunVal) c.arg(0);
            return c.call("kink/_str/FORMAT", formatHandle).args(template, config);
        } else {
            List<Val> args = IntStream.range(0, c.argCount())
                .mapToObj(i -> c.arg(i))
                .collect(Collectors.toList());
            FunVal config = vm.fun.make().take(1).action(
                    cc -> formatMethodConfigAux(cc, cc.arg(0), args));
            return c.call("kink/_str/FORMAT", formatHandle).args(template, config);
        }
    }

    /**
     * Implementation of config fun for Str.format.
     *
     * <p>Calls {@code conf.arg(A)} for each arg in order, chaining the calls through
     * continuations, and returns nada at the end.</p>
     */
    private HostResult formatMethodConfigAux(HostContext c, Val conf, List<Val> args) {
        if (args.isEmpty()) {
            return vm.nada;
        }

        return c.call(conf, this.argHandle).args(args.get(0))
            .on((cc, r) -> formatMethodConfigAux(cc, conf, args.subList(1, args.size())));
    }

    // }}}1
    // Str.op_mul(Count) {{{1

    /**
     * Implementation of Str.op_mul(Count): repeats the str Count times.
     *
     * <p>Raises if Count is not a non-negative int num, or if the result would exceed
     * the max length.</p>
     */
    private HostResult opMulMethod(HostContext c, StrVal strVal, Val countVal) {
        String str = strVal.string();
        BigInteger count = NumOperations.getExactBigInteger(countVal);
        if (count == null || count.signum() < 0) {
            return c.call(vm.graph.raiseFormat(
                        "Str.op_mul(Count): required int num >=0 for Count, but got {}",
                        vm.graph.repr(countVal)));
        }
        if (str.isEmpty()) {
            return vm.str.of("");
        }
        BigInteger resultLen = count.multiply(BigInteger.valueOf(str.length()));
        if (resultLen.compareTo(BigInteger.valueOf(vm.str.getMaxLength())) > 0) {
            return c.raise("Str.op_mul(Count): too long result");
        }
        // resultLen <= max length (an int), and count <= resultLen because str is non-empty,
        // so both conversions below are exact.
        int countInt = count.intValueExact();
        StringBuilder sb = new StringBuilder(resultLen.intValueExact());
        for (int i = 0; i < countInt; ++ i) {
            sb.append(str);
        }
        return vm.str.of(sb.toString());
    }

    // }}}1
    // Str.op_lt {{{1

    /**
     * Compares strings lexicographically by runes (code points), not by UTF-16 chars.
     *
     * @return true if x is strictly less than y.
     */
    private boolean lessThanByRunes(String x, String y) {
        int xi = 0;
        int yi = 0;
        while (true) {
            if (yi >= y.length()) {
                return false;
            }
            if (xi >= x.length()) {
                return true;
            }
            int xrune = x.codePointAt(xi);
            int yrune = y.codePointAt(yi);
            if (xrune < yrune) {
                return true;
            } else if (xrune > yrune) {
                return false;
            }
            xi = x.offsetByCodePoints(xi, 1);
            yi = y.offsetByCodePoints(yi, 1);
        }
    }

    // }}}1
    // Str.repr {{{1

    /** Mapping from special characters to their representations in rich string literals. */
    private static final Map<Character, String> CHAR_TO_REPR;

    static {
        Map<Character, String> map = new HashMap<>();
        map.put('\u0000', "\\0");
        map.put('\u0007', "\\a");
        map.put('\b', "\\b");
        map.put('\t', "\\t");
        map.put('\n', "\\n");
        map.put('\u000b', "\\v");
        map.put('\f', "\\f");
        map.put('\r', "\\r");
        map.put('\u001b', "\\e");
        map.put('"', "\\\"");
        map.put('\\', "\\\\");
        map.put('\u007f', "\\x{00007f}");
        // Remaining C0 control chars without a named escape use the \x{......} form.
        for (char ch = '\u0001'; ch <= '\u001f'; ++ ch) {
            map.putIfAbsent(ch, String.format(Locale.ROOT, "\\x{%06x}", (int) ch));
        }
        CHAR_TO_REPR = Collections.unmodifiableMap(map);
    }

    /**
     * Implementation of Str.repr: returns the str as a double-quoted rich string literal.
     */
    StrVal reprMethod(String src) {
        StringBuilder sb = new StringBuilder("\"");
        for (char ch : src.toCharArray()) {
            String repr = CHAR_TO_REPR.getOrDefault(ch, Character.toString(ch));
            sb.append(repr);
        }
        sb.append('"');
        return vm.str.of(sb.toString());
    }

    // }}}1
    // Str.show(...[$config]) {{{1

    /**
     * Implementation of Str.show: returns the receiver itself; the optional config is ignored.
     */
    private HostResult showMethod(CallContext c) {
        return c.recv();
    }

    // }}}1

    /**
     * Makes a unary str operator, which takes no args and checks that the receiver is a str.
     */
    private FunVal unaryOp(String prefix, Function<String, HostResult> action) {
        return vm.fun.make(prefix).take(0).action(c -> {
            if (! (c.recv() instanceof StrVal)) {
                return c.call(vm.graph.raiseFormat("{}: required str as \\recv, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(c.recv())));
            }
            String recv = ((StrVal) c.recv()).string();
            return action.apply(recv);
        });
    }

    /**
     * Makes a binary str operator, which checks that both the receiver and the single arg are strs.
     */
    private FunVal binaryOp(String prefix, BiFunction<String, String, HostResult> action) {
        return vm.fun.make(prefix).take(1).action(c -> {
            if (! (c.recv() instanceof StrVal)) {
                return c.call(vm.graph.raiseFormat("{}: required str as \\recv, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(c.recv())));
            }
            String recv = ((StrVal) c.recv()).string();

            Val argVal = c.arg(0);
            if (! (argVal instanceof StrVal)) {
                return c.call(vm.graph.raiseFormat("{}: the arg must be a str, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(argVal)));
            }
            String arg = ((StrVal) argVal).string();

            return action.apply(recv, arg);
        });
    }

    /**
     * Makes a method fun taking one arg; only the receiver is checked to be a str.
     */
    private FunVal method1(
            String prefix, ThrowingFunction3<CallContext, StrVal, Val, HostResult> action) {
        return vm.fun.make(prefix).take(1).action(c -> {
            Val recv = c.recv();
            if (! (recv instanceof StrVal)) {
                return c.call(vm.graph.raiseFormat("{}: required str as \\recv, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(recv)));
            }
            return action.apply(c, (StrVal) recv, c.arg(0));
        });
    }

    /**
     * Makes a method fun taking two args; only the receiver is checked to be a str.
     */
    private FunVal method2(
            String prefix, ThrowingFunction4<CallContext, String, Val, Val, HostResult> action) {
        return vm.fun.make(prefix).take(2).action(c -> {
            Val recv = c.recv();
            if (! (recv instanceof StrVal)) {
                return c.call(vm.graph.raiseFormat("{}: required str as \\recv, but got {}",
                            vm.graph.of(vm.str.of(prefix)),
                            vm.graph.repr(recv)));
            }
            return action.apply(c, ((StrVal) recv).string(), c.arg(0), c.arg(1));
        });
    }

}

// vim: et sw=4 sts=4 fdm=marker
