/*
 * Decompiled with CFR 0.152.
 */
package org.aya.core.visitor;

import kala.collection.SeqLike;
import kala.collection.Set;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.mutable.MutableMap;
import kala.collection.mutable.MutableSet;
import kala.control.Either;
import kala.control.Result;
import kala.tuple.Unit;
import org.aya.concrete.stmt.Decl;
import org.aya.core.Matching;
import org.aya.core.Meta;
import org.aya.core.def.CtorDef;
import org.aya.core.def.FieldDef;
import org.aya.core.def.FnDef;
import org.aya.core.def.PrimDef;
import org.aya.core.def.StructDef;
import org.aya.core.pat.PatMatcher;
import org.aya.core.term.CallTerm;
import org.aya.core.term.IntroTerm;
import org.aya.core.term.Term;
import org.aya.core.visitor.Subst;
import org.aya.core.visitor.TermFixpoint;
import org.aya.generic.Arg;
import org.aya.generic.Modifier;
import org.aya.ref.DefVar;
import org.aya.ref.Var;
import org.aya.tyck.TyckState;
import org.aya.util.error.WithPos;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/**
 * A {@link TermFixpoint} that, besides rebuilding terms, unfolds the definitions whose bodies it
 * can see: function calls ({@link #visitFnCall}), constructor calls carrying clauses
 * ({@link #visitConCall}), primitive calls ({@link #visitPrimCall}), solved metas
 * ({@link #visitHole}) and field accesses ({@link #visitAccess}).
 *
 * @param <P> the visitor parameter threaded through every {@code visit*} call
 */
public interface Unfolder<P> extends TermFixpoint<P> {
    /**
     * The type-checking state consulted for meta solutions and handed to primitive unfolding.
     * When {@code null}, {@link #visitHole} leaves holes untouched.
     */
    @Nullable TyckState state();

    /**
     * Builds a substitution mapping each parameter of {@code self} to the term of the argument at
     * the same position. The two sequences are paired with {@code zip} (presumably truncating to
     * the shorter one — confirm against kala's {@code zip}).
     *
     * @param self the telescope whose parameters become the substitution's keys
     * @param args the arguments whose terms become the substitution's values
     * @return a fresh substitution; neither input is modified
     */
    @Contract(pure = true)
    static @NotNull Subst buildSubst(@NotNull SeqLike<Term.@NotNull Param> self, @NotNull SeqLike<@NotNull Arg<@NotNull Term>> args) {
        var subst = new Subst(MutableMap.<Var, Term>create());
        self.view().zip(args).forEach(pair -> subst.add(pair._1.ref(), pair._2.term()));
        return subst;
    }

    /**
     * Unfolds a constructor call whose constructor carries clauses.
     *
     * <p>The constructor's own arguments are matched against its clauses in order-independent
     * mode; on a match the (already visited) clause body replaces the call.
     *
     * @return the unfolded clause body; the call itself if the constructor is not type checked
     *         yet; otherwise a call rebuilt over the visited constructor arguments
     */
    @Override
    default @NotNull Term visitConCall(@NotNull CallTerm.Con conCall, P p) {
        var def = conCall.ref().core;
        // Not type checked yet: there is nothing to unfold.
        if (def == null) return conCall;
        var args = conCall.args().map(arg -> visitArg(arg, p));
        var ulift = ulift() + conCall.ulift();
        // args() starts with the data-type arguments; only the constructor's own arguments
        // take part in clause matching.
        var conArgs = args.drop(conCall.head().dataArgs().size());
        var volynskaya = tryUnfoldClauses(p, true, conArgs, ulift, def.clauses);
        // NOTE(review): the fallback reuses conCall.head() unchanged, so its data arguments are
        // not the visited ones and `ulift` is not applied to it — confirm this is intended.
        return volynskaya != null ? volynskaya.data() : new CallTerm.Con(conCall.head(), conArgs);
    }

    /**
     * Unfolds a function call.
     *
     * <ul>
     *   <li>Functions not yet type checked are returned unchanged.</li>
     *   <li>{@link Modifier#Opaque} functions are rebuilt over visited arguments, never unfolded.</li>
     *   <li>A term body is instantiated with the arguments, lifted by the combined ulift, visited
     *       and renamed.</li>
     *   <li>A clause body is matched; {@link Modifier#Overlap} enables order-independent matching.
     *       When no clause fires, the call is rebuilt over the visited arguments.</li>
     * </ul>
     */
    @Override
    default @NotNull Term visitFnCall(@NotNull CallTerm.Fn fnCall, P p) {
        var def = fnCall.ref().core;
        // Not type checked yet: there is nothing to unfold.
        if (def == null) return fnCall;
        var args = fnCall.args().map(arg -> visitArg(arg, p));
        var ulift = ulift() + fnCall.ulift();
        if (def.modifiers.contains(Modifier.Opaque)) return new CallTerm.Fn(fnCall.ref(), ulift, args);
        var body = def.body;
        if (body.isLeft()) {
            var termSubst = checkAndBuildSubst(def.telescope(), args);
            return body.getLeftValue().subst(termSubst, ulift).accept(this, p).rename();
        }
        var orderIndependent = def.modifiers.contains(Modifier.Overlap);
        var volynskaya = tryUnfoldClauses(p, orderIndependent, args, ulift, body.getRightValue());
        // NOTE(review): tryUnfoldClauses already visits the clause body; it is visited once more
        // here (unlike visitConCall). Kept as-is to preserve behavior.
        return volynskaya != null
            ? volynskaya.data().accept(this, p)
            : new CallTerm.Fn(fnCall.ref(), ulift, args);
    }

    /**
     * Builds the substitution for {@code telescope} applied to {@code args}. Despite its name it
     * currently performs no check of its own and simply delegates to {@link #buildSubst}.
     */
    private @NotNull Subst checkAndBuildSubst(SeqLike<Term.Param> telescope, SeqLike<Arg<Term>> args) {
        return buildSubst(telescope, args);
    }

    /**
     * Delegates primitive unfolding to {@link PrimDef.Factory#INSTANCE}, passing {@link #state()}.
     * The call's arguments are not visited here.
     */
    @Override
    default @NotNull Term visitPrimCall(@NotNull CallTerm.Prim prim, P p) {
        return PrimDef.Factory.INSTANCE.unfold(prim.id(), prim, state());
    }

    /**
     * Replaces a hole by its solution when {@link #state()} records one.
     *
     * <p>The solution is instantiated with the hole's visited full argument list and then visited.
     * Holes are returned unchanged when there is no state or the meta is unsolved.
     */
    @Override
    default @NotNull Term visitHole(@NotNull CallTerm.Hole hole, P p) {
        var meta = hole.ref();
        var state = state();
        if (state == null) return hole;
        var metas = state.metas();
        if (!metas.containsKey(meta)) return hole;
        var solution = metas.get(meta);
        var args = hole.fullArgs().map(arg -> visitArg(arg, p)).toImmutableSeq();
        var subst = checkAndBuildSubst(meta.fullTelescope(), args);
        return solution.subst(subst).accept(this, p);
    }

    /**
     * Same as {@link #tryUnfoldClauses(Object, boolean, SeqLike, Subst, int, ImmutableSeq)}, but
     * starts from a fresh, empty substitution.
     */
    default @Nullable WithPos<Term> tryUnfoldClauses(P p, boolean orderIndependent, SeqLike<Arg<Term>> args, int ulift, @NotNull ImmutableSeq<Matching> clauses) {
        return tryUnfoldClauses(p, orderIndependent, args, new Subst(MutableMap.<Var, Term>create()), ulift, clauses);
    }

    /**
     * Tries the clauses in order and returns the body of the first whose patterns match {@code args}.
     *
     * <p>The chosen body is renamed, substituted with {@code subst} plus the pattern bindings,
     * lifted by {@code ulift}, and then visited. Note that {@code subst} is mutated: the bindings
     * of the matching clause are added to it.
     *
     * <p>When a clause fails with {@code Err(true)} (presumably "matching got stuck" — confirm in
     * {@link PatMatcher}) and matching is order dependent, the search stops and {@code null} is
     * returned. Other failures move on to the next clause.
     *
     * @return the visited body tagged with the matching clause's source position, or {@code null}
     */
    default @Nullable WithPos<Term> tryUnfoldClauses(P p, boolean orderIndependent, SeqLike<Arg<Term>> args, @NotNull Subst subst, int ulift, @NotNull ImmutableSeq<Matching> clauses) {
        for (var matchy : clauses) {
            var termSubst = PatMatcher.tryBuildSubstArgs(null, matchy.patterns(), args);
            if (termSubst.isOk()) {
                subst.add(termSubst.get());
                var newBody = matchy.body().view().rename().subst(subst).lift(ulift).commit().accept(this, p);
                return new WithPos<>(matchy.sourcePos(), newBody);
            }
            if (!orderIndependent && termSubst.getErr()) return null;
        }
        return null;
    }

    /**
     * Unfolds a field access.
     *
     * <ul>
     *   <li>If the accessed struct value reduces to {@link IntroTerm.New}, the field's value is
     *       projected out and applied to the field arguments.</li>
     *   <li>Otherwise the field's clauses are tried. Every other field of the struct is bound to a
     *       lambda re-accessing that field on the same value.</li>
     *   <li>When nothing fires, the access is rebuilt over the visited arguments.</li>
     * </ul>
     */
    @Override
    default @NotNull Term visitAccess(@NotNull CallTerm.Access term, P p) {
        var fieldRef = term.ref();
        var fieldDef = fieldRef.core;
        // Not type checked yet: nothing to unfold (mirrors visitFnCall / visitConCall);
        // previously this dereferenced null further down.
        if (fieldDef == null) return term;
        var nevv = term.of().accept(this, p);
        if (nevv instanceof IntroTerm.New n) {
            var arguments = buildSubst(fieldDef.ownerTele, term.structArgs());
            var fieldBody = term.fieldArgs().foldLeft(n.params().get(fieldRef), CallTerm::make);
            return fieldBody.subst(arguments).accept(this, p);
        }
        // args() holds the struct arguments followed by the field's own arguments.
        var args = term.args().map(arg -> visitArg(arg, p));
        var fieldSubst = checkAndBuildSubst(fieldDef.fullTelescope(), args);
        var structDef = fieldDef.structRef.core;
        var structArgsSize = term.structArgs().size();
        var structArgs = args.take(structArgsSize);
        for (var field : structDef.fields) {
            if (field == fieldDef) continue;
            var tele = field.telescope();
            var access = new CallTerm.Access(nevv, field.ref, structArgs, tele.map(Term.Param::toArg));
            fieldSubst.add(field.ref, IntroTerm.Lambda.make(tele, access));
        }
        var fieldArgs = args.drop(structArgsSize);
        var mischa = tryUnfoldClauses(p, true, fieldArgs, fieldSubst, 0, fieldDef.clauses);
        // The stuck fallback uses the *visited* struct arguments, consistent with the accesses
        // built above (it previously reused the unvisited term.structArgs()).
        return mischa != null
            ? mischa.data().subst(fieldSubst)
            : new CallTerm.Access(nevv, fieldRef, structArgs, fieldArgs);
    }

    /**
     * An unfolder that only unfolds function and constructor calls whose reference is listed in
     * {@code unfolding}, recording each reference it attempts to unfold into {@code unfolded}.
     *
     * <p>NOTE(review): {@code factory} is not read by any method here; {@link #visitPrimCall}
     * uses {@link PrimDef.Factory#INSTANCE} instead.
     *
     * @param unfolding the references that may be unfolded
     * @param unfolded  collects the references that were visited for unfolding (mutated)
     * @param state     the type-checking state, see {@link Unfolder#state()}
     * @param factory   the primitive factory
     */
    record Tracked(
        @NotNull Set<@NotNull Var> unfolding,
        @NotNull MutableSet<@NotNull Var> unfolded,
        @Nullable TyckState state,
        @NotNull PrimDef.Factory factory
    ) implements Unfolder<Unit> {
        @Override
        public @NotNull Term visitFnCall(@NotNull CallTerm.Fn fnCall, Unit unit) {
            if (!unfolding.contains(fnCall.ref())) return fnCall;
            unfolded.add(fnCall.ref());
            return Unfolder.super.visitFnCall(fnCall, unit);
        }

        @Override
        public @NotNull Term visitConCall(@NotNull CallTerm.Con conCall, Unit unit) {
            if (!unfolding.contains(conCall.ref())) return conCall;
            unfolded.add(conCall.ref());
            return Unfolder.super.visitConCall(conCall, unit);
        }
    }
}

