/*
 * Decompiled with CFR 0.152.
 */
package org.aya.tyck.unify;

import kala.collection.Seq;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.mutable.MutableArrayList;
import kala.control.Option;
import kala.tuple.Tuple2;
import org.aya.core.Meta;
import org.aya.core.ops.Eta;
import org.aya.core.term.ErrorTerm;
import org.aya.core.term.MetaTerm;
import org.aya.core.term.PiTerm;
import org.aya.core.term.RefTerm;
import org.aya.core.term.Term;
import org.aya.core.visitor.DeltaExpander;
import org.aya.core.visitor.Subst;
import org.aya.core.visitor.VarConsumer;
import org.aya.generic.util.InternalException;
import org.aya.generic.util.NormalizeMode;
import org.aya.ref.AnyVar;
import org.aya.ref.LocalVar;
import org.aya.tyck.TyckState;
import org.aya.tyck.env.LocalCtx;
import org.aya.tyck.env.MapLocalCtx;
import org.aya.tyck.error.HoleProblem;
import org.aya.tyck.trace.Trace;
import org.aya.tyck.unify.TermComparator;
import org.aya.util.Arg;
import org.aya.util.Ordering;
import org.aya.util.error.SourcePos;
import org.aya.util.reporter.Problem;
import org.aya.util.reporter.Reporter;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/**
 * A {@link TermComparator} that can also solve metavariables: when the left-hand side of a
 * comparison is a {@link MetaTerm} applied to a spine of variables, {@link #solveMeta} inverts
 * that spine and assigns the substituted right-hand side as the meta's solution.
 */
public final class Unifier extends TermComparator {
    /**
     * When {@code false}, a problem whose spine contains overlapping variables that the
     * right-hand side mentions is postponed as an equation instead of being solved.
     */
    final boolean allowVague;
    /**
     * When {@code true}, a solution with "confused" variables (as reported by the scope check)
     * is still assigned, and an equation is additionally recorded; otherwise it is an error.
     */
    final boolean allowConfused;
    /** Applied to each spine argument before checking whether it is a variable reference. */
    @NotNull
    private final Eta uneta;

    public Unifier(@NotNull Ordering cmp, @NotNull Reporter reporter, boolean allowVague, boolean allowConfused, @Nullable Trace.Builder traceBuilder, @NotNull TyckState state, @NotNull SourcePos pos, @NotNull LocalCtx ctx) {
        super(traceBuilder, state, reporter, pos, cmp, ctx);
        this.allowVague = allowVague;
        this.allowConfused = allowConfused;
        this.uneta = new Eta(ctx);
    }

    /**
     * Packages {@code lhs} / {@code rhs} as a postponed equation for {@link TyckState#addEqn}.
     * The local context is forwarded into a fresh {@link MapLocalCtx} for both sides
     * (presumably capturing the locals they mention — confirm against {@code LocalCtx.forward}).
     */
    @NotNull
    private TyckState.Eqn createEqn(@NotNull Term lhs, @NotNull Term rhs, TermComparator.Sub lr, TermComparator.Sub rl) {
        MapLocalCtx local = new MapLocalCtx();
        this.ctx.forward(local, lhs, this.state);
        this.ctx.forward(local, rhs, this.state);
        // Clone the substitutions so the stored equation keeps a snapshot of their current state.
        return new TyckState.Eqn(lhs, rhs, this.cmp, this.pos, local, lr.clone(), rl.clone());
    }

    /**
     * Inverts the spine of {@code lhs} into {@code subst}: each argument (after un-eta) must be a
     * variable reference, which is mapped to the corresponding parameter of {@code meta.telescope}.
     * <p>
     * A variable that is already a key of {@code subst} when encountered is recorded as an
     * overlap: its old binding is removed and it is rebound to the current parameter; further
     * occurrences of an overlapping variable are skipped.
     *
     * @return the overlapping variables, or {@code null} if some argument is not a variable
     */
    @Nullable
    private Seq<LocalVar> invertSpine(Subst subst, @NotNull MetaTerm lhs, @NotNull Meta meta) {
        MutableArrayList<LocalVar> overlap = MutableArrayList.create();
        for (Tuple2<Arg<Term>, Term.Param> arg : lhs.args().zipView(meta.telescope)) {
            Term term = this.uneta.uneta(arg._1.term());
            if (!(term instanceof RefTerm)) {
                return null;
            }
            LocalVar spineVar = ((RefTerm) term).var();
            if (overlap.contains(spineVar)) {
                continue;
            }
            if (subst.map().containsKey(spineVar)) {
                overlap.append(spineVar);
                subst.map().remove(spineVar);
            }
            subst.add(spineVar, arg._2.toTerm());
        }
        return overlap;
    }

    /**
     * Tries to solve {@code lhs} (a meta application) with {@code preRhs}.
     *
     * @return the type of {@code preRhs} when the meta is solved or the problem is postponed;
     *         an {@link ErrorTerm} after reporting a scope or recursion error; {@code null} after
     *         reporting a bad spine; or the result of {@link #sameMeta} when both sides are the
     *         same meta
     */
    @Override
    @Nullable
    protected Term solveMeta(@NotNull Term preRhs, TermComparator.Sub lr, TermComparator.Sub rl, @NotNull MetaTerm lhs) {
        Meta meta = lhs.ref();
        // Same meta on both sides: compare the spines pointwise instead of solving.
        Option<Term> sameMeta = this.sameMeta(lr, rl, lhs, meta, preRhs);
        if (sameMeta.isDefined()) {
            return sameMeta.get();
        }
        Term resultTy = preRhs.computeType(this.state, this.ctx);
        if (meta.result != null) {
            // resultTy comes from the right-hand side, hence rl/lr are passed in swapped order.
            // NOTE(review): the outcome of this comparison is not checked here.
            this.compare(resultTy, meta.result, rl, lr, null);
        }
        Subst subst = DeltaExpander.buildSubst(meta.contextTele, lhs.contextArgs());
        Seq<LocalVar> overlap = this.invertSpine(subst, lhs, meta);
        if (overlap == null) {
            this.reporter.report(new HoleProblem.BadSpineError(lhs, this.pos));
            return null;
        }
        // An overlapping variable has no unique inverse; unless vague solutions are allowed,
        // postpone the problem when the right-hand side actually mentions one.
        if (!this.allowVague && overlap.anyMatch(v -> preRhs.findUsages(v) > 0)) {
            this.state.addEqn(this.createEqn(lhs, preRhs, lr, rl));
            return resultTy;
        }
        rl.map().forEach(subst::add);
        assert !this.state.metas().containsKey(meta);
        Term solved = preRhs.freezeHoles(this.state).subst(subst);
        // After substitution the right-hand side may have become the same meta.
        sameMeta = this.sameMeta(lr, rl, lhs, meta, solved);
        if (sameMeta.isDefined()) {
            return sameMeta.get();
        }
        ImmutableSeq<LocalVar> allowedVars = meta.fullTelescope().map(Term.Param::ref).toImmutableSeq();
        VarConsumer.ScopeChecker scopeCheck = solved.scopeCheck(allowedVars);
        if (scopeCheck.invalid.isNotEmpty()) {
            // Retry once on the normal form, which may no longer mention the offending variables.
            solved = solved.normalize(this.state, NormalizeMode.NF);
            scopeCheck = solved.scopeCheck(allowedVars);
        }
        if (scopeCheck.invalid.isNotEmpty()) {
            this.reporter.report(new HoleProblem.BadlyScopedError(lhs, solved, (Seq<LocalVar>) scopeCheck.invalid, this.pos));
            return new ErrorTerm(solved);
        }
        if (scopeCheck.confused.isNotEmpty()) {
            if (this.allowConfused) {
                this.state.addEqn(this.createEqn(lhs, solved, lr, rl));
            } else {
                this.reporter.report(new HoleProblem.BadlyScopedError(lhs, solved, (Seq<LocalVar>) scopeCheck.confused, this.pos));
                return new ErrorTerm(solved);
            }
        }
        if (!meta.solve(this.state, solved)) {
            this.reporter.report(new HoleProblem.RecursionError(lhs, solved, this.pos));
            return new ErrorTerm(solved);
        }
        this.tracing(builder -> builder.append(new Trace.LabelT(this.pos, "Hole solved!")));
        return resultTy;
    }

    /**
     * Handles the case where {@code preRhs} is an application of the same {@code meta} as
     * {@code lhs} by comparing the two spines argument-by-argument against the meta's Pi type.
     *
     * @return {@code none} if {@code preRhs} is not an application of {@code meta};
     *         {@code some(null)} if the meta has no known result type or an argument comparison
     *         fails; otherwise {@code some(ty)} with the Pi type instantiated by the lhs arguments
     * @throws InternalException if there are more arguments than the meta's Pi type accepts
     */
    private @NotNull Option<@Nullable Term> sameMeta(TermComparator.Sub lr, TermComparator.Sub rl, @NotNull MetaTerm lhs, Meta meta, Term preRhs) {
        if (!(preRhs instanceof MetaTerm)) {
            return Option.none();
        }
        MetaTerm rcall = (MetaTerm) preRhs;
        if (rcall.ref() != meta) {
            return Option.none();
        }
        if (meta.result == null) {
            return Option.some(null);
        }
        Term holeTy = PiTerm.make(meta.telescope, meta.result);
        for (Tuple2<Arg<Term>, Arg<Term>> arg : lhs.args().zipView(rcall.args())) {
            if (!(holeTy instanceof PiTerm)) {
                throw new InternalException("meta arg size larger than param size. this should not happen");
            }
            PiTerm holePi = (PiTerm) holeTy;
            Term lhsArg = arg._1.term();
            if (!this.compare(lhsArg, arg._2.term(), lr, rl, holePi.param().type())) {
                return Option.some(null);
            }
            // Dependent Pi: instantiate the codomain with the argument just compared.
            holeTy = holePi.substBody(lhsArg);
        }
        return Option.some(holeTy);
    }

    /**
     * Re-checks a postponed equation by comparing both sides, untyped, after normalizing each
     * to weak-head normal form.
     */
    public void checkEqn(@NotNull TyckState.Eqn eqn) {
        Term lhsWhnf = eqn.lhs().normalize(this.state, NormalizeMode.WHNF);
        Term rhsWhnf = eqn.rhs().normalize(this.state, NormalizeMode.WHNF);
        this.compareUntyped(lhsWhnf, rhsWhnf, eqn.lr(), eqn.rl());
    }
}

