/*
 * Decompiled with CFR 0.152.
 */
package org.aya.core.visitor;

import kala.collection.SeqLike;
import kala.collection.Set;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.mutable.MutableMap;
import kala.collection.mutable.MutableSet;
import kala.control.Either;
import kala.tuple.Tuple2;
import kala.tuple.Unit;
import org.aya.api.ref.DefVar;
import org.aya.api.ref.Var;
import org.aya.api.util.Arg;
import org.aya.api.util.WithPos;
import org.aya.concrete.stmt.Decl;
import org.aya.core.Matching;
import org.aya.core.Meta;
import org.aya.core.def.CtorDef;
import org.aya.core.def.Def;
import org.aya.core.def.FieldDef;
import org.aya.core.def.FnDef;
import org.aya.core.def.PrimDef;
import org.aya.core.pat.PatMatcher;
import org.aya.core.sort.LevelSubst;
import org.aya.core.sort.Sort;
import org.aya.core.term.CallTerm;
import org.aya.core.term.IntroTerm;
import org.aya.core.term.Term;
import org.aya.core.visitor.Substituter;
import org.aya.core.visitor.TermFixpoint;
import org.aya.tyck.TyckState;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

public interface Unfolder<P>
extends TermFixpoint<P> {
    @Nullable
    public TyckState state();

    @Contract(pure=true)
    @NotNull
    public static Substituter.TermSubst buildSubst(@NotNull @NotNull SeqLike<@NotNull Term.Param> self, @NotNull @NotNull SeqLike<@NotNull Arg<@NotNull Term>> args) {
        Substituter.TermSubst subst = new Substituter.TermSubst((MutableMap<Var, Term>)MutableMap.create());
        self.view().zip(args).forEach(t -> subst.add((Var)((Term.Param)t._1).ref(), (Term)((Arg)t._2).term()));
        return subst;
    }

    @Override
    @NotNull
    default public Term visitConCall(@NotNull CallTerm.Con conCall, P p) {
        CtorDef def = (CtorDef)conCall.ref().core;
        if (def == null) {
            return conCall;
        }
        ImmutableSeq args = conCall.args().map(arg -> this.visitArg((Arg<Term>)arg, p));
        ImmutableSeq<Sort.LvlVar> levelParams = Def.defLevels(def.ref());
        ImmutableSeq<Sort> levelArgs = conCall.sortArgs();
        LevelSubst levelSubst = Unfolder.buildSubst(levelParams, levelArgs);
        ImmutableSeq dropped = args.drop(conCall.head().dataArgs().size());
        WithPos<Term> volynskaya = this.tryUnfoldClauses(p, (SeqLike<Arg<Term>>)dropped, levelSubst, (ImmutableSeq<Matching>)def.clauses);
        return volynskaya != null ? (Term)volynskaya.data() : new CallTerm.Con(conCall.head(), (ImmutableSeq<Arg<Term>>)dropped);
    }

    @NotNull
    public static LevelSubst buildSubst(ImmutableSeq<Sort.LvlVar> levelParams, ImmutableSeq<@NotNull Sort> levelArgs) {
        LevelSubst.Simple levelSubst = new LevelSubst.Simple((MutableMap<Sort.LvlVar, Sort>)MutableMap.create());
        if (!1.$assertionsDisabled && !levelParams.sizeEquals(levelArgs)) {
            throw new AssertionError();
        }
        for (Tuple2 app : levelArgs.zip(levelParams)) {
            levelSubst.solution().put((Object)((Sort.LvlVar)app._2), (Object)((Sort)app._1));
        }
        return levelSubst;
    }

    @Override
    @NotNull
    default public Term visitFnCall(@NotNull CallTerm.Fn fnCall, P p) {
        FnDef def = (FnDef)fnCall.ref().core;
        if (def == null) {
            return fnCall;
        }
        ImmutableSeq args = fnCall.args().map(arg -> this.visitArg((Arg<Term>)arg, p));
        LevelSubst levelSubst = Unfolder.buildSubst((ImmutableSeq<Sort.LvlVar>)def.levels, fnCall.sortArgs());
        Either<Term, ImmutableSeq<Matching>> body = def.body;
        if (body.isLeft()) {
            Substituter.TermSubst termSubst = this.checkAndBuildSubst((SeqLike<Term.Param>)def.telescope(), (SeqLike<Arg<Term>>)args);
            return ((Term)((Term)body.getLeftValue()).subst(termSubst, levelSubst).accept(this, p)).rename();
        }
        WithPos<Term> volynskaya = this.tryUnfoldClauses(p, (SeqLike<Arg<Term>>)args, levelSubst, (ImmutableSeq<Matching>)((ImmutableSeq)body.getRightValue()));
        return volynskaya != null ? (Term)((Term)volynskaya.data()).accept(this, p) : new CallTerm.Fn(fnCall.ref(), fnCall.sortArgs(), (ImmutableSeq<Arg<Term>>)args);
    }

    @NotNull
    private Substituter.TermSubst checkAndBuildSubst(SeqLike<Term.Param> telescope, SeqLike<Arg<Term>> args) {
        return Unfolder.buildSubst(telescope, args);
    }

    @Override
    @NotNull
    default public Term visitPrimCall(@NotNull CallTerm.Prim prim, P p) {
        return ((PrimDef)prim.ref().core).unfold(prim, this.state());
    }

    @Override
    @NotNull
    default public Term visitHole(@NotNull CallTerm.Hole hole, P p) {
        Meta def = hole.ref();
        TyckState state = this.state();
        if (state == null) {
            return hole;
        }
        MutableMap<Meta, Term> metas = state.metas();
        if (!metas.containsKey((Object)def)) {
            return hole;
        }
        Term body = (Term)metas.get((Object)def);
        ImmutableSeq args = hole.fullArgs().map(arg -> this.visitArg((Arg<Term>)arg, p)).toImmutableSeq();
        Substituter.TermSubst subst = this.checkAndBuildSubst((SeqLike<Term.Param>)def.fullTelescope(), (SeqLike<Arg<Term>>)args);
        return (Term)body.subst(subst).accept(this, p);
    }

    @Nullable
    default public WithPos<Term> tryUnfoldClauses(P p, SeqLike<Arg<Term>> args, LevelSubst levelSubst, @NotNull ImmutableSeq<Matching> clauses) {
        return this.tryUnfoldClauses(p, args, new Substituter.TermSubst((MutableMap<Var, Term>)MutableMap.create()), levelSubst, clauses);
    }

    @Nullable
    default public WithPos<Term> tryUnfoldClauses(P p, SeqLike<Arg<Term>> args, @NotNull Substituter.TermSubst subst, LevelSubst levelSubst, @NotNull ImmutableSeq<Matching> clauses) {
        for (Matching matchy : clauses) {
            Substituter.TermSubst termSubst = PatMatcher.tryBuildSubstArgs(matchy.patterns(), args);
            if (termSubst == null) continue;
            subst.add(termSubst);
            Term newBody = ((Term)matchy.body().subst(subst, levelSubst).accept(this, p)).rename();
            return new WithPos(matchy.sourcePos(), (Object)newBody);
        }
        return null;
    }

    @Override
    @NotNull
    default public Term visitAccess(@NotNull CallTerm.Access term, P p) {
        Term nevv = (Term)term.of().accept(this, p);
        DefVar<FieldDef, Decl.StructField> field = term.ref();
        FieldDef core = (FieldDef)field.core;
        if (!(nevv instanceof IntroTerm.New)) {
            ImmutableSeq args = term.args().map(arg -> this.visitArg((Arg<Term>)arg, p));
            Substituter.TermSubst fieldSubst = this.checkAndBuildSubst((SeqLike<Term.Param>)core.fullTelescope(), (SeqLike<Arg<Term>>)args);
            LevelSubst levelSubst = Unfolder.buildSubst(Def.defLevels(field), term.sortArgs());
            ImmutableSeq dropped = args.drop(term.structArgs().size());
            WithPos<Term> mischa = this.tryUnfoldClauses(p, (SeqLike<Arg<Term>>)dropped, fieldSubst, levelSubst, (ImmutableSeq<Matching>)core.clauses);
            return mischa != null ? (Term)mischa.data() : new CallTerm.Access(nevv, field, term.sortArgs(), term.structArgs(), (ImmutableSeq<Arg<Term>>)dropped);
        }
        IntroTerm.New n = (IntroTerm.New)nevv;
        Substituter.TermSubst arguments = Unfolder.buildSubst((SeqLike<Term.Param>)core.ownerTele, term.structArgs());
        Term fieldBody = (Term)term.fieldArgs().foldLeft((Object)((Term)n.params().get(field)), CallTerm::make);
        return (Term)fieldBody.subst(arguments).accept(this, p);
    }

    static {
        if (1.$assertionsDisabled) {
            // empty if block
        }
    }

    public record Tracked(@NotNull @NotNull Set<@NotNull Var> unfolding, @NotNull @NotNull MutableSet<@NotNull Var> unfolded, @Nullable TyckState state, @NotNull PrimDef.Factory factory) implements Unfolder<Unit>
    {
        @Override
        @NotNull
        public Term visitFnCall(@NotNull CallTerm.Fn fnCall, Unit unit) {
            if (!this.unfolding.contains(fnCall.ref())) {
                return fnCall;
            }
            this.unfolded.add(fnCall.ref());
            return Unfolder.super.visitFnCall(fnCall, unit);
        }

        @Override
        @NotNull
        public Term visitConCall(@NotNull CallTerm.Con conCall, Unit unit) {
            if (!this.unfolding.contains(conCall.ref())) {
                return conCall;
            }
            this.unfolded.add(conCall.ref());
            return Unfolder.super.visitConCall(conCall, unit);
        }
    }
}

