package org.kink_lang.kink.internal.program.itree;

import java.util.ArrayList;
import java.util.List;

import org.kink_lang.kink.internal.contract.Preconds;

/**
 * Transforms an itree traversing the sub itrees in the evaluation order.
 *
 * <p>Each node is rebuilt from its transformed sub itrees,
 * and then {@link Callback#itree(Itree)} is applied to the rebuilt node;
 * thus the callback sees sub itrees before the itree containing them.</p>
 *
 * <p>Note that body itrees in funs are not traversed.</p>
 */
public final class DeepTransformer {

    /**
     * Cannot be instantiated.
     */
    private DeepTransformer() {
    }

    /**
     * Transforms {@code itree} traversing the sub itrees in the evaluation order.
     *
     * <p>Note that body itrees of funs are not traversed.</p>
     *
     * @param itree the itree to be transformed.
     * @param callback the callback.
     * @return the result of transformation.
     */
    public static Itree deepTransform(Itree itree, Callback callback) {
        return itree.accept(new Visitor(callback));
    }

    /**
     * Callback for {@link DeepTransformer#deepTransform(Itree, Callback)}.
     */
    public interface Callback {

        /**
         * Transforms dereference of a local var.
         *
         * <p>The implicit dereferences of {@code new_val}, {@code if}, {@code branch}
         * and {@code true} performed for {@link NoTraitNewValItree}, {@link TraitNewValItree},
         * {@link IfItree}, {@link BranchItree}, and {@link BranchWithElseItree}
         * must not be transformed: the result must equal the argument.
         * The transformer checks it with {@code Preconds.checkState}.</p>
         *
         * @param lvar the local var.
         * @return the result of transformation.
         */
        LocalVar derefLvar(LocalVar lvar);

        /**
         * Transforms storing to a local var.
         *
         * @param lvar the original local var to be transformed.
         * @return the result of transformation.
         */
        LocalVar storeLvar(LocalVar lvar);

        /**
         * Transforms the itree.
         *
         * <p>It is called with an itree whose sub itrees are already transformed.</p>
         *
         * <p>If {@code itree} is {@link FastFunItree},
         * the result must also be {@link FastFunItree}.</p>
         *
         * @param itree the itree to be transformed.
         * @return the result of transformation.
         */
        Itree itree(Itree itree);

    }

    /**
     * Visitor for traversal.
     */
    private static final class Visitor extends SkeltonItreeVisitor<Itree> {

        /** Local var for {@code new_val}. */
        private static final LocalVar NEW_VAL_LVAR = new LocalVar.Original("new_val");

        /** Local var for {@code if}. */
        private static final LocalVar IF_LVAR = new LocalVar.Original("if");

        /** Local var for {@code branch}. */
        private static final LocalVar BRANCH_LVAR = new LocalVar.Original("branch");

        /** Local var for {@code true}. */
        private static final LocalVar TRUE_LVAR = new LocalVar.Original("true");

        /** The callback. */
        private final Callback callback;

        /**
         * Makes a visitor.
         *
         * @param callback the callback invoked for lvar derefs, lvar stores and itrees.
         */
        Visitor(Callback callback) {
            super(callback::itree);
            this.callback = callback;
        }

        /**
         * Transforms a sub itree.
         */
        private Itree transformSub(Itree itree) {
            return itree.accept(this);
        }

        /**
         * Transforms a sub itree which is known to be a fast fun.
         *
         * <p>The cast is safe because {@link Callback#itree(Itree)}
         * must return a {@link FastFunItree} for a {@link FastFunItree}.</p>
         */
        private FastFunItree transformFun(Itree fun) {
            return (FastFunItree) transformSub(fun);
        }

        /**
         * Transforms an itree elem.
         */
        private ItreeElem transformElem(ItreeElem elem) {
            return elem instanceof Itree itree
                ? transformSub(itree)
                : new ItreeElem.Spread(transformSub(elem.expr()), elem.pos());
        }

        /**
         * Transforms lvar stores, in the order of the list.
         */
        private List<LocalVar> transformStoreLvars(List<? extends LocalVar> lvars) {
            return lvars.stream()
                .map(this.callback::storeLvar)
                .toList();
        }

        /**
         * Callbacks lvar stores in nested params.
         */
        private List<NestedParam> transformNestedParams(List<NestedParam> params) {
            return params.stream()
                .map(p -> p instanceof NestedParam.Tuple t
                        ? new NestedParam.Tuple(transformStoreLvars(t.lvars()))
                        : this.callback.storeLvar((LocalVar) p))
                .toList();
        }

        /**
         * Transforms the vals of sym-val pairs, in the order of the list.
         */
        private List<SymValPair> transformSymValPairs(List<? extends SymValPair> pairs) {
            return pairs.stream()
                .map(pair -> new SymValPair(pair.sym(), transformSub(pair.val())))
                .toList();
        }

        /**
         * Transforms the funs of cond-then pairs, in the order of the list.
         */
        private List<CondThenPair> transformCondThenPairs(List<? extends CondThenPair> pairs) {
            return pairs.stream()
                .map(pair -> new CondThenPair(
                            transformFun(pair.condFun()),
                            transformFun(pair.thenFun())))
                .toList();
        }

        @Override
        public Itree visit(SeqItree itree) {
            var steps = itree.steps().stream()
                .map(this::transformSub)
                .toList();
            return this.callback.itree(new SeqItree(steps, itree.pos()));
        }

        @Override
        public FastFunItree visit(FastFunItree itree) {
            // the body is not traversed
            return (FastFunItree) this.callback.itree(itree);
        }

        @Override
        public Itree visit(VecItree itree) {
            var elems = itree.elems().stream()
                .map(this::transformElem)
                .toList();
            return this.callback.itree(new VecItree(elems, itree.pos()));
        }

        @Override
        public Itree visit(DerefItree itree) {
            Itree owner = transformSub(itree.owner());
            return this.callback.itree(new DerefItree(owner, itree.sym(), itree.pos()));
        }

        @Override
        public Itree visit(LderefItree itree) {
            var lvar = this.callback.derefLvar(itree.lvar());
            return this.callback.itree(new LderefItree(lvar, itree.pos()));
        }

        @Override
        public Itree visit(VarrefItree itree) {
            Itree owner = transformSub(itree.owner());
            return this.callback.itree(new VarrefItree(owner, itree.sym(), itree.pos()));
        }

        @Override
        public Itree visit(LetRecItree itree) {
            // lvars must be stored first
            var lvars = itree.lvarFunPairs().stream()
                .map(LetRecItree.LvarFunPair::lvar)
                .map(this.callback::storeLvar)
                .toList();

            // then funs are created
            var funs = itree.lvarFunPairs().stream()
                .map(pair -> transformFun(pair.fun()))
                .toList();

            List<LetRecItree.LvarFunPair> pairs = new ArrayList<>(lvars.size());
            for (int i = 0; i < lvars.size(); ++ i) {
                pairs.add(new LetRecItree.LvarFunPair(lvars.get(i), funs.get(i)));
            }
            return this.callback.itree(new LetRecItree(pairs, itree.pos()));
        }

        @Override
        public Itree visit(AssignmentItree itree) {
            var lhs = transformSub(itree.lhs());
            var rhs = transformSub(itree.rhs());
            return this.callback.itree(new AssignmentItree(lhs, rhs, itree.pos()));
        }

        @Override
        public Itree visit(OptVecAssignmentItree itree) {
            var rhs = transformSub(itree.rhs());
            var mandatory = transformStoreLvars(itree.mandatory());
            var opt = transformStoreLvars(itree.opt());
            return this.callback.itree(new OptVecAssignmentItree(
                        mandatory,
                        opt,
                        rhs,
                        itree.pos()));
        }

        @Override
        public Itree visit(OptRestVecAssignmentItree itree) {
            var rhs = transformSub(itree.rhs());
            var mandatory = transformStoreLvars(itree.mandatory());
            var opt = transformStoreLvars(itree.opt());
            var rest = this.callback.storeLvar(itree.rest());
            return this.callback.itree(new OptRestVecAssignmentItree(
                        mandatory,
                        opt,
                        rest,
                        rhs,
                        itree.pos()));
        }

        @Override
        public Itree visit(RestVecAssignmentItree itree) {
            var rhs = transformSub(itree.rhs());
            var lvar = this.callback.storeLvar(itree.lvar());
            return this.callback.itree(new RestVecAssignmentItree(
                        lvar,
                        rhs,
                        itree.pos()));
        }

        @Override
        public Itree visit(NestedVecAssignmentItree itree) {
            var rhs = transformSub(itree.rhs());
            var params = transformNestedParams(itree.params());
            return this.callback.itree(new NestedVecAssignmentItree(
                        params,
                        rhs,
                        itree.pos()));
        }

        @Override
        public Itree visit(VarrefVecAssignmentItree itree) {
            // owners of generic vars are evaluated before the rhs
            var genericEvaluated = itree.params().stream()
                .map(p -> p instanceof GenericVar g
                        ? new GenericVar(transformSub(g.owner()), g.name())
                        : p)
                .toList();
            var rhs = transformSub(itree.rhs());
            // local vars are stored after the rhs is evaluated
            var params = genericEvaluated.stream()
                .map(p -> p instanceof LocalVar lvar
                        ? this.callback.storeLvar(lvar)
                        : p)
                .toList();
            return this.callback.itree(new VarrefVecAssignmentItree(
                        params,
                        rhs,
                        itree.pos()));
        }

        @Override
        public Itree visit(LstoreItree itree) {
            var rhs = transformSub(itree.rhs());
            var lvar = this.callback.storeLvar(itree.lvar());
            return this.callback.itree(new LstoreItree(
                        lvar,
                        rhs,
                        itree.pos()));
        }

        @Override
        public Itree visit(StoreItree itree) {
            var owner = transformSub(itree.owner());
            var rhs = transformSub(itree.rhs());
            return this.callback.itree(new StoreItree(
                        owner,
                        itree.sym(),
                        rhs,
                        itree.pos()));
        }

        @Override
        public Itree visit(ArgsPassingItree itree) {
            var lvars = transformStoreLvars(itree.lvars());
            return this.callback.itree(new ArgsPassingItree(lvars, itree.pos()));
        }

        @Override
        public Itree visit(NestedArgsPassingItree itree) {
            var params = transformNestedParams(itree.params());
            return this.callback.itree(new NestedArgsPassingItree(params, itree.pos()));
        }

        @Override
        public Itree visit(BiArithmeticItree itree) {
            // only the receiver is a sub itree; op and arg are kept as is
            var recv = transformSub(itree.recv());
            return this.callback.itree(new BiArithmeticItree(
                        recv,
                        itree.op(),
                        itree.arg(),
                        itree.pos()));
        }

        @Override
        public Itree visit(NoTraitNewValItree itree) {
            var lvar = this.callback.derefLvar(NEW_VAL_LVAR);
            Preconds.checkState(lvar.equals(NEW_VAL_LVAR), "sym of new_val must not change");
            var symValPairs = transformSymValPairs(itree.symValPairs());
            return this.callback.itree(new NoTraitNewValItree(
                        symValPairs,
                        itree.pos()));
        }

        @Override
        public Itree visit(TraitNewValItree itree) {
            var lvar = this.callback.derefLvar(NEW_VAL_LVAR);
            Preconds.checkState(lvar.equals(NEW_VAL_LVAR), "sym of new_val must not change");
            var trait = transformSub(itree.trait());
            var symValPairs = transformSymValPairs(itree.symValPairs());
            return this.callback.itree(new TraitNewValItree(
                        trait,
                        itree.spreadPos(),
                        symValPairs,
                        itree.pos()));
        }

        @Override
        public Itree visit(IfItree itree) {
            var lvar = this.callback.derefLvar(IF_LVAR);
            Preconds.checkState(lvar.equals(IF_LVAR), "sym of if must not change");
            var cond = transformSub(itree.cond());
            var trueFun = transformFun(itree.trueFun());
            var falseFun = itree.falseFun().map(this::transformFun);
            return this.callback.itree(new IfItree(
                        cond,
                        trueFun,
                        falseFun,
                        itree.pos()));
        }

        @Override
        public Itree visit(BranchItree itree) {
            var lvar = this.callback.derefLvar(BRANCH_LVAR);
            Preconds.checkState(lvar.equals(BRANCH_LVAR), "sym of branch must not change");
            var condThenPairs = transformCondThenPairs(itree.condThenPairs());
            return this.callback.itree(new BranchItree(
                        condThenPairs,
                        itree.pos()));
        }

        @Override
        public Itree visit(BranchWithElseItree itree) {
            var branchLvar = this.callback.derefLvar(BRANCH_LVAR);
            Preconds.checkState(branchLvar.equals(BRANCH_LVAR),
                    "sym of branch must not change");
            var condThenPairs = transformCondThenPairs(itree.condThenPairs());

            // `true` is the cond of the last pair, so it is dereferenced
            // after all the preceding cond/then funs and before the else-then fun
            var trueLvar = this.callback.derefLvar(TRUE_LVAR);
            Preconds.checkState(trueLvar.equals(TRUE_LVAR),
                    "sym of true must not change");
            var elseThenFun = transformFun(itree.elseThenFun());
            return this.callback.itree(new BranchWithElseItree(
                        condThenPairs,
                        elseThenFun,
                        itree.pos()));
        }

        @Override
        public Itree visit(McallItree itree) {
            var ownerRecv = transformSub(itree.ownerRecv());
            var args = itree.args().stream()
                .map(this::transformElem)
                .toList();
            return this.callback.itree(new McallItree(
                        ownerRecv,
                        itree.sym(),
                        args,
                        itree.pos()));
        }

        @Override
        public Itree visit(SymcallItree itree) {
            var fun = transformSub(itree.fun());
            var recv = transformSub(itree.recv());
            var args = itree.args().stream()
                .map(this::transformElem)
                .toList();
            return this.callback.itree(new SymcallItree(
                        fun,
                        itree.sym(),
                        recv,
                        args,
                        itree.pos()));
        }

    }

}

// vim: et sw=4 sts=4 fdm=marker
