package org.kink_lang.kink.internal.program.ast;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.function.Function;
import java.util.function.BiFunction;

import javax.annotation.Nullable;

import org.kink_lang.kink.internal.program.lex.NumToken;
import org.kink_lang.kink.internal.program.lex.EotToken;
import org.kink_lang.kink.internal.program.lex.ErrorToken;
import org.kink_lang.kink.internal.program.lex.MarkToken;
import org.kink_lang.kink.internal.program.lex.NounToken;
import org.kink_lang.kink.internal.program.lex.StrToken;
import org.kink_lang.kink.internal.program.lex.VerbToken;
import org.kink_lang.kink.internal.program.lex.Token;

/**
 * A single run of parsing.
 *
 * <p>This is a recursive-descent parser which converts a list of tokens
 * into an AST rooted at a {@link SeqExpr}.
 * The expression levels, from the loosest binding to the tightest, are:
 *
 * <ol>
 *   <li>store: {@code <-}, non-associative</li>
 *   <li>logical or: {@code ||}, right-associative</li>
 *   <li>logical and: {@code &&}, right-associative</li>
 *   <li>relation: {@code == != < > <= >=}, non-associative</li>
 *   <li>additive: {@code + - | ^}, left-associative</li>
 *   <li>multiplicative: {@code * / // % & << >>}, left-associative</li>
 *   <li>unary: prefix {@code - ~ !}</li>
 *   <li>primary: an atom followed by {@code .}, {@code $} and {@code :} suffixes</li>
 * </ol>
 *
 * <p>Operators are desugared into method calls ({@link McallExpr}),
 * or into {@link RcallExpr} nodes whose owner is a {@link BindingExpr}.
 *
 * <p>The instance keeps a cursor into the token list which is never reset,
 * so an instance is meant to be used for a single invocation of {@link #run()}.
 */
class ParseRun {

    /**
     * Tokens to be parsed.
     * Expected to be terminated by an {@link EotToken}; this is not checked here.
     */
    private final List<Token> tokens;

    /** The index of the current token in {@link #tokens}. */
    private int index = 0;

    /**
     * Constructs a run of parsing.
     *
     * @param locale the locale; currently not used by this class.
     * @param tokens the tokens to parse. The parser stops only on an {@link EotToken}
     *     and does not check bounds, so the list is expected to end with one.
     */
    ParseRun(Locale locale, List<Token> tokens) {
        this.tokens = tokens;
    }

    /**
     * Invokes the run.
     *
     * @return the parsed program.
     * @throws CompileException if an error token is encountered,
     *     or the tokens do not form a valid program.
     */
    SeqExpr run() throws CompileException {
        return parseProgram();
    }

    // parse program {{{1

    /**
     * Parses the whole program: a sequence of expressions up to the end of text.
     *
     * <p>Unlike {@link #parseSeq(Predicate, boolean)}, a let clause is not allowed
     * on the top level; a step starting with '=' is reported as an error.
     */
    private SeqExpr parseProgram() throws CompileException {
        int startPos = peekToken().startPos();
        List<Expr> steps = new ArrayList<>();
        while (! (peekToken() instanceof EotToken)) {
            int exprPos = peekToken().startPos();
            // For «:X = 1», «:X» is parsed as one step, then the next step starts with '='.
            if (isMark(peekToken(), "=")) {
                throw new CompileException(getMsgLetOnTopLevel(), exprPos, exprPos);
            }
            Expr expr = parseExpression();
            steps.add(expr);
        }
        return new SeqExpr(steps, startPos);
    }

    // }}}1
    // parse seq {{{1

    /**
     * Parses a seq: expressions until {@code atEnd} matches the current token.
     *
     * <p>When an expression is followed by '=', the rest of the seq is parsed
     * as a let clause by {@link #parseLet(Expr, int, Predicate)}.
     * The terminating token is not consumed.
     *
     * @param atEnd tests whether the token terminates the seq.
     * @param disallowsEmpty if true, an empty seq is reported as an error;
     *     the error message is the one for a let clause without a following expression.
     */
    private SeqExpr parseSeq(
            Predicate<Token> atEnd, boolean disallowsEmpty) throws CompileException {
        int seqPos = peekToken().startPos();
        if (disallowsEmpty && atEnd.test(peekToken())) {
            throw new CompileException(getMsgLetMustBeFollowedByExpression(), seqPos, seqPos);
        }
        List<Expr> steps = new ArrayList<>();
        while (! atEnd.test(peekToken())) {
            int exprPos = peekToken().startPos();
            Expr expr = parseExpression();
            if (isMark(peekToken(), "=")) {
                // parseLet consumes the remainder of the seq, so the loop ends after this.
                steps.add(parseLet(expr, exprPos, atEnd));
            } else {
                steps.add(expr);
            }
        }
        return new SeqExpr(steps, seqPos);
    }

    /**
     * Parses a let clause «Lhs = Rhs Rest...», where the current token is '='.
     *
     * <p>The clause is desugared into a fun whose body is
     * «[Lhs].op_store(_Args)» followed by the seq of Rest,
     * where _Args is dereferenced from the binding.
     * The fun is then called via its «call» method with two arguments:
     * an empty seq, and a vec containing Rhs.
     *
     * @param lhs the left-hand side expression, already parsed.
     * @param lhsPos the start position of the left-hand side.
     * @param atEnd tests whether the token terminates the enclosing seq.
     */
    private Expr parseLet(Expr lhs, int lhsPos, Predicate<Token> atEnd) throws CompileException {
        MarkToken letMark = readMarkToken("=");
        int letPos = letMark.startPos();
        Expr lhsVec = new VecExpr(List.of(lhs), letPos);
        Expr rhs = parseExpression();
        // The rest of the enclosing seq becomes the body of the fun; it must not be empty.
        SeqExpr subsequent = parseSeq(atEnd, true);
        Expr argVec = new DerefExpr(new BindingExpr(letPos), "_Args", letPos);
        McallExpr assign = new McallExpr(lhsVec, "op_store", List.of(argVec), letPos);
        List<Expr> funBody = List.of(assign, subsequent);
        SeqExpr funSeq = new SeqExpr(funBody, lhsPos);
        FunExpr fun = new FunExpr(funSeq, lhsPos);
        SeqExpr nada = new SeqExpr(List.of(), letPos);
        VecExpr rhsVec = new VecExpr(List.of(rhs), letPos);
        return new McallExpr(fun, "call", List.of(nada, rhsVec), letPos);
    }

    // }}}1
    // parse expression {{{1

    /**
     * Parses an expression, starting from the loosest-binding level.
     */
    private Expr parseExpression() throws CompileException {
        return parseStore();
    }

    // }}}1
    // parse operators {{{1

    /**
     * Parses op_store: «Head <- Arg», desugared into «Head.op_store(Arg)».
     *
     * <p>The operator is non-associative: both sides are parsed at the logical-or level.
     */
    private Expr parseStore() throws CompileException {
        Expr head = parseLogOr();
        if (! isMark(peekToken(), "<-")) {
            return head;
        }

        MarkToken op = readMarkToken("<-");
        Expr arg = parseLogOr();
        return new McallExpr(head, "op_store",
                List.of(arg),
                op.startPos());
    }

    /**
     * Parses logor_op: «X || Y», which is right-associative.
     */
    private Expr parseLogOr() throws CompileException {
        Expr head = parseLogAnd();
        if (! isMark(peekToken(), "||")) {
            return head;
        }

        MarkToken op = (MarkToken) readToken();
        int funStart = peekToken().startPos();
        Expr body = parseLogOr();
        return makeShortCircuitCall("op_logor", head, body, funStart, op.startPos());
    }

    /**
     * Parses logand_op: «X && Y», which is right-associative.
     */
    private Expr parseLogAnd() throws CompileException {
        Expr head = parseRel();
        if (! isMark(peekToken(), "&&")) {
            return head;
        }

        MarkToken op = (MarkToken) readToken();
        int funStart = peekToken().startPos();
        Expr body = parseLogAnd();
        return makeShortCircuitCall("op_logand", head, body, funStart, op.startPos());
    }

    /**
     * Makes an {@link RcallExpr} of a short-circuit operator on the binding.
     *
     * <p>The call receives an empty seq as the recv, and two args: the left operand as is,
     * and the right operand wrapped in a {@link FunExpr}
     * (presumably so that its evaluation can be skipped by the callee).
     *
     * @param verb the operator verb, such as "op_logor".
     * @param head the left operand.
     * @param body the right operand.
     * @param funStart the start position of the right operand.
     * @param opPos the position of the operator mark.
     * @return the call expression.
     */
    private static Expr makeShortCircuitCall(
            String verb, Expr head, Expr body, int funStart, int opPos) {
        FunExpr fun = new FunExpr(body, funStart);
        BindingExpr binding = new BindingExpr(opPos);
        Expr nada = new SeqExpr(List.of(), opPos);
        return new RcallExpr(binding, verb, nada,
                List.of(head, fun),
                opPos);
    }

    /**
     * Makes an «op_lognot» call on the binding, which negates {@code operand}.
     *
     * @param operand the expression to negate.
     * @param pos the position of the generated nodes.
     * @return the call expression.
     */
    private static Expr makeLogNot(Expr operand, int pos) {
        BindingExpr binding = new BindingExpr(pos);
        Expr nada = new SeqExpr(List.of(), pos);
        return new RcallExpr(binding, "op_lognot", nada, List.of(operand), pos);
    }

    /**
     * Relationship operators.
     *
     * <p>Each handler takes the left operand, the operator token and the right operand,
     * and makes the node. All the comparisons are expressed in terms of
     * «op_eq», «op_lt» and «op_lognot»:
     * «x != y» is «!(x == y)», «x > y» is «y.op_lt(x)»,
     * «x <= y» is «!(y < x)», and «x >= y» is «!(x < y)».
     */
    private static final Map<String, Function<Expr, Function<Token, Function<Expr, Expr>>>> REL_OPS;

    static {
        Map<String, Function<Expr, Function<Token, Function<Expr, Expr>>>> map = new HashMap<>();
        map.put("==", x -> op -> y -> {
            return new McallExpr(x, "op_eq",
                    List.of(y),
                    op.startPos());
        });
        map.put("!=", x -> op -> y -> {
            Expr eqExpr = new McallExpr(x, "op_eq",
                    List.of(y),
                    op.startPos());
            return makeLogNot(eqExpr, op.startPos());
        });
        map.put("<", x -> op -> y -> {
            return new McallExpr(x, "op_lt",
                    List.of(y),
                    op.startPos());
        });
        map.put(">", x -> op -> y -> {
            return new McallExpr(y, "op_lt",
                    List.of(x),
                    op.startPos());
        });
        map.put("<=", x -> op -> y -> {
            Expr ltExpr = new McallExpr(y, "op_lt",
                    List.of(x),
                    op.startPos());
            return makeLogNot(ltExpr, op.startPos());
        });
        map.put(">=", x -> op -> y -> {
            Expr ltExpr = new McallExpr(x, "op_lt",
                    List.of(y),
                    op.startPos());
            return makeLogNot(ltExpr, op.startPos());
        });
        REL_OPS = Map.copyOf(map);
    }

    /**
     * Parses relation_op: «X op Y», where op is one of {@link #REL_OPS}.
     *
     * <p>The operators are non-associative: at most one operator is consumed.
     */
    private Expr parseRel() throws CompileException {
        Expr x = parseAdd();
        var toRelNode = getOpHandlerOrNull(peekToken(), REL_OPS);
        if (toRelNode == null) {
            return x;
        }

        MarkToken op = (MarkToken) readToken();
        Expr y = parseAdd();
        return toRelNode.apply(x).apply(op).apply(y);
    }

    /** Additive operators, mapped to the verbs of the method calls. */
    private static final Map<String, String> ADD_OPS;

    static {
        Map<String, String> map = new HashMap<>();
        map.put("+", "op_add");
        map.put("-", "op_sub");
        map.put("|", "op_or");
        map.put("^", "op_xor");
        ADD_OPS = Map.copyOf(map);
    }

    /**
     * Parses add_op: a left-associative chain of {@link #ADD_OPS},
     * each desugared into «Head.verb(Arg)».
     */
    private Expr parseAdd() throws CompileException {
        Expr head = parseMultiply();
        while (true) {
            String verb = getOpHandlerOrNull(peekToken(), ADD_OPS);
            if (verb == null) {
                return head;
            }

            MarkToken op = (MarkToken) readToken();
            Expr arg = parseMultiply();
            head = new McallExpr(head, verb, List.of(arg), op.startPos());
        }
    }

    /** Multiplicative operators, mapped to the verbs of the method calls. */
    private static final Map<String, String> MUL_OPS;

    static {
        Map<String, String> map = new HashMap<>();
        map.put("*", "op_mul");
        map.put("/", "op_div");
        map.put("//", "op_intdiv");
        map.put("%", "op_rem");
        map.put("&", "op_and");
        map.put("<<", "op_shl");
        map.put(">>", "op_shr");
        MUL_OPS = Map.copyOf(map);
    }

    /**
     * Parses mul_op: a left-associative chain of {@link #MUL_OPS},
     * each desugared into «Head.verb(Arg)».
     */
    private Expr parseMultiply() throws CompileException {
        Expr head = parseUnary();
        while (true) {
            String verb = getOpHandlerOrNull(peekToken(), MUL_OPS);
            if (verb == null) {
                return head;
            }

            MarkToken op = (MarkToken) readToken();
            Expr arg = parseUnary();
            head = new McallExpr(head, verb, List.of(arg), op.startPos());
        }
    }

    /**
     * Unary operators.
     *
     * <p>Each handler takes the operand and the position of the operator.
     * «-X» is «X.op_minus», «~X» is «X.op_not»,
     * and «!X» is an «op_lognot» call on the binding.
     */
    private static final Map<String, BiFunction<Expr, Integer, Expr>> UNARY_OPS;

    static {
        Map<String, BiFunction<Expr, Integer, Expr>> map = new HashMap<>();
        map.put("-", (operand, pos) -> new McallExpr(operand, "op_minus", List.of(), pos));
        map.put("~", (operand, pos) -> new McallExpr(operand, "op_not", List.of(), pos));
        map.put("!", (operand, pos) -> makeLogNot(operand, pos));
        UNARY_OPS = Map.copyOf(map);
    }

    /**
     * Parses unary_op: any number of prefix {@link #UNARY_OPS} followed by a primary.
     */
    private Expr parseUnary() throws CompileException {
        var makeNode = getOpHandlerOrNull(peekToken(), UNARY_OPS);
        if (makeNode == null) {
            return parsePrimary();
        }

        MarkToken op = (MarkToken) readToken();
        return makeNode.apply(parseUnary(), op.startPos());
    }

    /**
     * If the token is a mark and the mark is an operator in the keySet of opToHandler,
     * returns the corresponding handler; otherwise, returns null.
     *
     * @param token the token.
     * @param opToHandler the mapping from operators to handlers.
     * @param <H> the type of handlers.
     * @return the handler, or null if absent.
     */
    @Nullable
    private <H> H getOpHandlerOrNull(Token token, Map<String, H> opToHandler) {
        if (! (token instanceof MarkToken)) {
            return null;
        }

        String mark = ((MarkToken) token).mark();
        return opToHandler.get(mark);
    }

    // }}}1
    // parse primary {{{1

    /**
     * Parses primary: an atom followed by any number of suffixes.
     *
     * <p>A suffix is either «.» followed by a noun or a call,
     * «$verb» placed not after whitespace, or «:sym» placed not after whitespace.
     */
    private Expr parsePrimary() throws CompileException {
        Expr head = parseAtom();
        while (true) {
            if (isMark(peekToken(), ".")) {
                head = parseDot(head);
            } else if (isMarkNotAfterWhitespace(peekToken(), "$")) {
                head = parseAttrVerbDeref(head);
            } else if (isMarkNotAfterWhitespace(peekToken(), ":")) {
                head = parseAttrVarref(head);
            } else {
                return head;
            }
        }
    }

    /**
     * Parses after ".": either a noun dereference or an attributional call.
     *
     * @param head the owner expression before ".".
     */
    private Expr parseDot(Expr head) throws CompileException {
        readMarkToken(".");
        if (peekToken() instanceof NounToken) {
            return parseAttrNounDeref(head);
        } else if (peekToken() instanceof VerbToken) {
            return parseAttrCall(head);
        } else {
            int pos = peekToken().startPos();
            throw new CompileException(getMsgVerbOrNounExpected(), pos, pos);
        }
    }

    /**
     * Parses an attributional noun dereference «Head.Noun».
     *
     * @param head the owner expression.
     */
    private Expr parseAttrNounDeref(Expr head) throws CompileException {
        NounToken noun = (NounToken) readToken();
        return new DerefExpr(head, noun.noun(), noun.startPos());
    }

    /**
     * Parses an attributional verb dereference «Head$verb», starting from DOLLAR.
     *
     * @param head the owner expression.
     */
    private Expr parseAttrVerbDeref(Expr head) throws CompileException {
        readMarkToken("$");
        VerbToken verb = readVerbToken();
        return new DerefExpr(head, verb.verb(), verb.startPos());
    }

    /**
     * Parses an attributional varref «Head:sym», starting from COLON.
     * The sym is either a noun or a verb.
     *
     * @param head the owner expression.
     */
    private Expr parseAttrVarref(Expr head) throws CompileException {
        readMarkToken(":");
        Token symToken = readToken();
        String sym = asSym(symToken);
        return new VarrefExpr(head, sym, symToken.startPos());
    }

    // }}}1
    // parse atoms {{{1

    /**
     * Parses an atom expression, dispatching on the current token.
     */
    private Expr parseAtom() throws CompileException {
        Token token = peekToken();
        if (token instanceof StrToken) {
            return parseStr();
        } else if (token instanceof NumToken) {
            return parseNum();
        } else if (token instanceof VerbToken) {
            return parseLocalCall();
        } else if (token instanceof NounToken) {
            return parseLocalNounDeref();
        } else if (isMark(token, "$")) {
            return parseLocalVerbDeref();
        } else if (isMark(token, ":")) {
            return parseLocalVarref();
        } else if (isMark(token, "(")) {
            return parseParen();
        } else if (isMark(token, "{")) {
            return parseFun();
        } else if (isMark(token, "[")) {
            return parseVec();
        } else if (isMark(token, "\\binding")) {
            readMarkToken("\\binding");
            return new BindingExpr(token.startPos());
        } else if (token instanceof EotToken) {
            int errPos = token.startPos();
            throw new CompileException(getMsgUnexpectedEndOfText(), errPos, errPos);
        } else {
            int errPos = token.startPos();
            throw new CompileException(getMsgUnexpectedToken(), errPos, errPos);
        }
    }

    /**
     * Parses a str expression.
     */
    private StrExpr parseStr() throws CompileException {
        StrToken st = (StrToken) readToken();
        return new StrExpr(st.value(), st.startPos());
    }

    /**
     * Parses a num expression.
     */
    private NumExpr parseNum() throws CompileException {
        NumToken nt = (NumToken) readToken();
        return new NumExpr(nt.decimal(), nt.startPos());
    }

    /**
     * Parses a paren «( Seq )»; the result is the seq itself.
     * Let clauses are allowed inside.
     */
    private Expr parseParen() throws CompileException {
        readMarkToken("(");
        Expr seq = parseSeq(t -> isMark(t, ")"), false);
        readMarkToken(")");
        return seq;
    }

    /**
     * Parses a local noun dereference «Noun», which is dereferenced from the binding.
     */
    private Expr parseLocalNounDeref() throws CompileException {
        NounToken noun = (NounToken) readToken();
        BindingExpr binding = new BindingExpr(noun.startPos());
        return new DerefExpr(binding, noun.noun(), noun.startPos());
    }

    /**
     * Parses a local verb dereference «$verb», which is dereferenced from the binding.
     */
    private Expr parseLocalVerbDeref() throws CompileException {
        MarkToken dollar = readMarkToken("$");
        VerbToken verb = readVerbToken();
        BindingExpr binding = new BindingExpr(dollar.startPos());
        return new DerefExpr(binding, verb.verb(), verb.startPos());
    }

    /**
     * Parses a local varref «:sym», whose owner is the binding.
     * The sym is either a noun or a verb.
     */
    private Expr parseLocalVarref() throws CompileException {
        MarkToken colon = readMarkToken(":");
        Token symToken = readToken();
        String sym = asSym(symToken);
        BindingExpr binding = new BindingExpr(colon.startPos());
        return new VarrefExpr(binding, sym, symToken.startPos());
    }

    // }}}1
    // parse calls {{{1

    /**
     * Parses local_call: «verb[Recv](Args...){...}», called on the binding.
     * When the recv is omitted, an empty seq is used as the recv.
     */
    private Expr parseLocalCall() throws CompileException {
        VerbToken verb = readVerbToken();
        Expr recv = parseActualRecv()
            .orElseGet(() -> new SeqExpr(List.of(), verb.startPos()));
        List<Elem> args = parseActualArgs();
        BindingExpr binding = new BindingExpr(verb.startPos());
        return new RcallExpr(binding, verb.verb(), recv, args, verb.startPos());
    }

    /**
     * Parses attr_call: «Owner.verb[Recv](Args...){...}».
     *
     * <p>With an explicit recv, the result is an {@link RcallExpr};
     * otherwise it is an {@link McallExpr} on the owner.
     *
     * @param owner the owner expression before ".".
     */
    private Expr parseAttrCall(Expr owner) throws CompileException {
        VerbToken verbToken = readVerbToken();
        Optional<Expr> optRecv = parseActualRecv();
        List<Elem> args = parseActualArgs();

        String verb = verbToken.verb();
        int pos = verbToken.startPos();
        return optRecv.map(recv -> (Expr) new RcallExpr(owner, verb, recv, args, pos))
            .orElseGet(() -> new McallExpr(owner, verb, args, pos));
    }

    /**
     * Parses an actual recv «[Recv]» placed not after whitespace.
     *
     * @return the recv expression, or empty if no recv is given.
     */
    private Optional<Expr> parseActualRecv() throws CompileException {
        if (! isMarkNotAfterWhitespace(peekToken(), "[")) {
            return Optional.empty();
        }

        readMarkToken("[");
        Expr recv = parseExpression();
        readMarkToken("]");
        return Optional.of(recv);
    }

    /**
     * Parses a list of actual args: an optional «(Args...)»
     * followed by any number of funs, each placed not after whitespace.
     *
     * @return the args; empty if none are given.
     */
    private List<Elem> parseActualArgs() throws CompileException {
        List<Elem> args = new ArrayList<>();

        if (isMarkNotAfterWhitespace(peekToken(), "(")) {
            readMarkToken("(");
            args.addAll(parseVecBody(t -> isMark(t, ")")));
            readMarkToken(")");
        }

        while (isMarkNotAfterWhitespace(peekToken(), "{")) {
            args.add(parseFun());
        }

        return args;
    }

    // }}}1
    // parse fun {{{1

    /**
     * Parses a fun «{ [Recv](Args...) Body... }».
     *
     * <p>The recv passing and the arg vec passing are optional,
     * and each must be placed not after whitespace.
     * They are converted to «op_store» calls placed before the body seq.
     */
    private FunExpr parseFun() throws CompileException {
        MarkToken opener = readMarkToken("{");
        int seqStart = peekToken().startPos();
        List<Expr> exprs = new ArrayList<>();

        if (isMarkNotAfterWhitespace(peekToken(), "[")) {
            exprs.add(parseRecvPassing());
        }

        if (isMarkNotAfterWhitespace(peekToken(), "(")) {
            exprs.add(parseArgVecPassing());
        }

        SeqExpr body = parseSeq(t -> isMark(t, "}"), false);
        exprs.add(body);

        readMarkToken("}");
        SeqExpr seq = new SeqExpr(exprs, seqStart);
        return new FunExpr(seq, opener.startPos());
    }

    /**
     * Parses a recv passing «[Lhs]», converted to «Lhs.op_store(_Recv)»
     * where _Recv is dereferenced from the binding.
     */
    private McallExpr parseRecvPassing() throws CompileException {
        MarkToken open = readMarkToken("[");
        Expr lhs = parseExpression();
        readMarkToken("]");
        int pos = open.startPos();
        Expr rhs = new DerefExpr(new BindingExpr(pos), "_Recv", pos);
        return new McallExpr(lhs, "op_store", List.of(rhs), pos);
    }

    /**
     * Parses an arg vec passing «(Params...)», converted to «[Params...].op_store(_Args)»
     * where _Args is dereferenced from the binding.
     */
    private McallExpr parseArgVecPassing() throws CompileException {
        MarkToken open = readMarkToken("(");
        int pos = open.startPos();
        List<Elem> lhsElems = parseVecBody(t -> isMark(t, ")"));
        readMarkToken(")");
        VecExpr lhs = new VecExpr(lhsElems, pos);
        Expr rhs = new DerefExpr(new BindingExpr(pos), "_Args", pos);
        return new McallExpr(lhs, "op_store", List.of(rhs), pos);
    }

    // }}}1
    // parse vec {{{1

    /**
     * Parses a vec «[Elems...]».
     */
    private Expr parseVec() throws CompileException {
        MarkToken open = readMarkToken("[");
        List<Elem> elems = parseVecBody(t -> isMark(t, "]"));
        VecExpr vec = new VecExpr(elems, open.startPos());
        readMarkToken("]");
        return vec;
    }

    /**
     * Parses vec_body: elements until {@code atEnd} matches the current token.
     * Each element is either a spreader «...Expr» or an expression.
     * The terminating token is not consumed.
     *
     * @param atEnd tests whether the token terminates the body.
     */
    private List<Elem> parseVecBody(Predicate<Token> atEnd) throws CompileException {
        List<Elem> elems = new ArrayList<>();
        while (! atEnd.test(peekToken())) {
            Token token = peekToken();
            elems.add(isMark(token, "...")
                    ? parseSpreader()
                    : parseExpression());
        }
        return elems;
    }

    /**
     * Parses a spreader «...Expr».
     */
    private Elem parseSpreader() throws CompileException {
        MarkToken dots = readMarkToken("...");
        Expr expr = parseExpression();
        return new Elem.Spread(expr, dots.startPos());
    }

    // }}}1
    // handling tokens {{{1

    /**
     * Returns the current token without consuming it.
     *
     * @throws CompileException if the current token is an error token.
     */
    private Token peekToken() throws CompileException {
        Token token = tokens.get(index);
        if (token instanceof ErrorToken) {
            ErrorToken et = (ErrorToken) token;
            throw new CompileException(et.msg(), et.startPos(), et.endPos());
        }
        return token;
    }

    /**
     * Returns the current token and increments the index.
     *
     * @throws CompileException if the current token is an error token.
     */
    private Token readToken() throws CompileException {
        Token token = peekToken();
        ++ this.index;
        return token;
    }

    /**
     * Consumes the current token, checks that it is the specified mark,
     * and returns it.
     *
     * @param mark the expected mark.
     * @throws CompileException if the token is not the mark.
     */
    private MarkToken readMarkToken(String mark) throws CompileException {
        Token token = readToken();
        if (! isMark(token, mark)) {
            int pos = token.startPos();
            throw new CompileException(getMsgTokenIsNot(mark), pos, pos);
        }
        return (MarkToken) token;
    }

    /**
     * Consumes the current token, checks that it is a verb,
     * and returns it.
     *
     * @throws CompileException if the token is not a verb.
     */
    private VerbToken readVerbToken() throws CompileException {
        Token token = readToken();
        if (! (token instanceof VerbToken)) {
            int pos = token.startPos();
            throw new CompileException(getMsgVerbExpected(), pos, pos);
        }
        return (VerbToken) token;
    }

    /**
     * Checks that the given token is a verb or a noun,
     * and returns the sym.
     *
     * @throws CompileException if the token is neither a verb nor a noun.
     */
    private String asSym(Token token) throws CompileException {
        if (token instanceof NounToken) {
            return ((NounToken) token).noun();
        } else if (token instanceof VerbToken) {
            return ((VerbToken) token).verb();
        } else {
            int pos = token.startPos();
            throw new CompileException(getMsgVerbOrNounExpected(), pos, pos);
        }
    }

    /**
     * Returns true if the token is a mark token with the specified mark.
     */
    private boolean isMark(Token token, String mark) {
        return token instanceof MarkToken
            && ((MarkToken) token).mark().equals(mark);
    }

    /**
     * Returns true if the token is a mark token with the specified mark,
     * and it is placed not after whitespace.
     */
    private boolean isMarkNotAfterWhitespace(Token token, String mark) {
        return isMark(token, mark)
            && ! ((MarkToken) token).isAfterWhitespace();
    }

    // }}}1
    // messages {{{1

    /**
     * Returns a message which indicates an unexpected '=' on the top level.
     */
    String getMsgLetOnTopLevel() {
        return "unexpected '=' on the top level; do you mean '<-'?";
    }

    /**
     * Returns an error message which indicates a let clause without a following expression.
     */
    String getMsgLetMustBeFollowedByExpression() {
        return "let clause like «:X = Val» must be followed by an expression";
    }

    /**
     * Returns a message which indicates an unexpected token.
     */
    String getMsgUnexpectedToken() {
        return "unexpected token";
    }

    /**
     * Returns a message which indicates an unexpected end-of-text.
     */
    String getMsgUnexpectedEndOfText() {
        return "unexpected end of program text";
    }

    /**
     * Returns a message which indicates that a verb is expected.
     */
    String getMsgVerbExpected() {
        return "expected a verb such as «foo»";
    }

    /**
     * Returns a message which indicates that a verb or a noun is expected.
     */
    String getMsgVerbOrNounExpected() {
        return "expected a verb such as «foo» or a noun such as «Bar»";
    }

    /**
     * Returns a message which indicates the token is not the expected mark.
     *
     * @param mark the expected mark.
     */
    String getMsgTokenIsNot(String mark) {
        return String.format(Locale.ROOT, "expected mark: %s", mark);
    }

    // }}}1

}

// vim: et sw=4 sts=4 fdm=marker
