/*
 * Decompiled with CFR 0.152.
 */
package org.aya.tyck.pat;

import java.lang.runtime.SwitchBootstraps;
import java.util.Objects;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;
import kala.collection.Seq;
import kala.collection.SeqLike;
import kala.collection.SeqView;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.immutable.primitive.ImmutableIntArray;
import kala.collection.immutable.primitive.ImmutableIntSeq;
import kala.collection.mutable.MutableList;
import kala.collection.mutable.MutableSeq;
import kala.value.MutableValue;
import kala.value.primitive.MutableBooleanValue;
import org.aya.generic.AyaDocile;
import org.aya.generic.Renamer;
import org.aya.normalize.Finalizer;
import org.aya.normalize.LetReplacer;
import org.aya.prettier.AyaPrettierOptions;
import org.aya.syntax.concrete.Expr;
import org.aya.syntax.concrete.Pattern;
import org.aya.syntax.core.Jdg;
import org.aya.syntax.core.def.FnClauseBody;
import org.aya.syntax.core.pat.Pat;
import org.aya.syntax.core.pat.PatToTerm;
import org.aya.syntax.core.term.AppTerm;
import org.aya.syntax.core.term.DepTypeTerm;
import org.aya.syntax.core.term.ErrorTerm;
import org.aya.syntax.core.term.FreeTerm;
import org.aya.syntax.core.term.MetaPatTerm;
import org.aya.syntax.core.term.Param;
import org.aya.syntax.core.term.Term;
import org.aya.syntax.ref.GenerateKind;
import org.aya.syntax.ref.LocalCtx;
import org.aya.syntax.ref.LocalVar;
import org.aya.tyck.ExprTycker;
import org.aya.tyck.TyckState;
import org.aya.tyck.ctx.LocalLet;
import org.aya.tyck.error.PatternProblem;
import org.aya.tyck.pat.PatBinder;
import org.aya.tyck.pat.PatClassifier;
import org.aya.tyck.pat.PatternTycker;
import org.aya.tyck.pat.iter.LambdaPusheen;
import org.aya.tyck.pat.iter.PatternIterator;
import org.aya.tyck.pat.iter.SignatureIterator;
import org.aya.tyck.tycker.Problematic;
import org.aya.tyck.tycker.Stateful;
import org.aya.util.Arg;
import org.aya.util.Pair;
import org.aya.util.Panic;
import org.aya.util.PrettierOptions;
import org.aya.util.position.SourceNode;
import org.aya.util.position.SourcePos;
import org.aya.util.position.WithPos;
import org.aya.util.reporter.Reporter;
import org.aya.util.tyck.pat.PatClass;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

public final class ClauseTycker
implements Problematic,
Stateful {
    @NotNull
    private final ExprTycker exprTycker;
    private final Finalizer.Zonk<ClauseTycker> zonker = new Finalizer.Zonk<ClauseTycker>(this);

    public ClauseTycker(@NotNull ExprTycker exprTycker) {
        this.exprTycker = exprTycker;
    }

    @Override
    @NotNull
    public Reporter reporter() {
        return this.exprTycker.reporter;
    }

    @Override
    @NotNull
    public TyckState state() {
        return this.exprTycker.state;
    }

    @NotNull
    public MutableSeq<LhsResult> checkAllLhs(@NotNull Supplier<SignatureIterator> sigIterFactory, @NotNull SeqView<Pattern.Clause> clauses, int userUnpiSize) {
        return MutableSeq.from((Iterable)clauses.map(c -> this.checkLhs((SignatureIterator)sigIterFactory.get(), (Pattern.Clause)c, true, userUnpiSize)));
    }

    @NotNull
    public TyckResult checkAllRhs(@NotNull ImmutableSeq<LocalVar> vars, @NotNull Seq<LhsResult> lhsResults, boolean lhsError) {
        ImmutableSeq rhsResult = lhsResults.map(x -> this.checkRhs(vars, (LhsResult)x));
        rhsResult = rhsResult.map(preclause -> new Pat.Preclause(preclause.sourcePos(), preclause.pats().map(p -> p.descentTerm(this.zonker::zonk)), preclause.bindCount(), preclause.expr()));
        return new TyckResult((ImmutableSeq<Pat.Preclause<Term>>)rhsResult, lhsError);
    }

    @NotNull
    public LhsResult checkLhs(@NotNull SignatureIterator sigIter, @NotNull Pattern.Clause clause, boolean isFn, int userUnpiSize) {
        PatternTycker tycker = this.newPatternTycker(sigIter, sigIter.elims != null);
        try (ExprTycker.SubscopedNoVar subscopedNoVar = this.exprTycker.subscope();){
            if (isFn && !clause.patterns.anyMatch(p -> ClauseTycker.hasAbsurdity((Pattern)((WithPos)p.term()).data())) && clause.expr.isEmpty()) {
                clause.hasError = true;
                this.exprTycker.fail(new PatternProblem.InvalidEmptyBody(clause));
            }
            PatternIterator patIter = new PatternIterator((ImmutableSeq<Arg<WithPos<Pattern>>>)clause.patterns, clause.expr.isDefined() ? new LambdaPusheen((WithPos<Expr>)((WithPos)clause.expr.get())) : PatternIterator.DUMMY);
            PatternTycker.TyckResult patResult = tycker.tyck(patIter, null);
            LocalCtx ctx = this.exprTycker.localCtx();
            clause.hasError |= patResult.hasError();
            patResult = ClauseTycker.inline(patResult, ctx);
            clause.patterns.forEach(it -> TermInPatInline.apply((Pattern)((WithPos)it.term()).data()));
            ctx = ctx.map((UnaryOperator)new TermInline());
            Term instRepi = sigIter.unpiBody().makePi().instTele(patResult.paramSubst().view().map(Jdg::wellTyped));
            DepTypeTerm.Unpi instUnpiParam = DepTypeTerm.unpiDBI((Term)instRepi, UnaryOperator.identity(), (int)userUnpiSize);
            ImmutableSeq missingPats = instUnpiParam.params().mapIndexed((idx, x) -> new Pat.Bind(new LocalVar("unpi" + idx, SourcePos.NONE, (GenerateKind)GenerateKind.Basic.Tyck), x.type()));
            ImmutableSeq wellTypedPats = patResult.wellTyped().appendedAll((Iterable)missingPats);
            LhsResult lhsResult = new LhsResult(ctx, instRepi, userUnpiSize, (ImmutableSeq<Pat>)wellTypedPats, clause.sourcePos, patIter.exprBody(), patResult.paramSubst(), patResult.asSubst(), patResult.hasError());
            return lhsResult;
        }
    }

    @NotNull
    private Pat.Preclause<Term> checkRhs(@NotNull ImmutableSeq<LocalVar> teleBinds, @NotNull LhsResult result) {
        try (ExprTycker.SubscopedNoVar subscopedNoVar = this.exprTycker.subscope();){
            ErrorTerm wellBody;
            WithPos<Expr> bodyExpr = result.body;
            int bindCount = 0;
            ImmutableSeq pats = result.freePats();
            if (bodyExpr == null) {
                wellBody = null;
            } else if (result.hasError) {
                wellBody = new ErrorTerm((AyaDocile)bodyExpr.data());
            } else {
                this.exprTycker.setLocalCtx(result.localCtx);
                result.dumpLocalLetTo(teleBinds, this.exprTycker);
                wellBody = this.exprTycker.inherit(bodyExpr, result.result()).wellTyped();
                this.exprTycker.solveMetas();
                wellBody = this.zonker.zonk((Term)wellBody);
                Pair patWithTypeBound = Pat.collectVariables((SeqView)result.freePats().view());
                pats = (ImmutableSeq)patWithTypeBound.component2();
                MutableList patBindTele = (MutableList)patWithTypeBound.component1();
                bindCount = patBindTele.size();
                wellBody = AppTerm.make((Term)wellBody, (SeqView)pats.view().takeLast(result.unpiParamSize).map(PatToTerm::visit));
                wellBody = wellBody.bindTele(patBindTele.view());
            }
            Pat.Preclause preclause = new Pat.Preclause(result.sourcePos, pats, bindCount, wellBody == null ? null : WithPos.dummy((Object)wellBody));
            return preclause;
        }
    }

    @NotNull
    private PatternTycker newPatternTycker(@NotNull SignatureIterator sigIter, boolean hasElim) {
        return new PatternTycker(this.exprTycker, sigIter, new LocalLet(), !hasElim, new Renamer());
    }

    private static boolean hasAbsurdity(@NotNull Pattern term) {
        return ClauseTycker.hasAbsurdity(term, MutableBooleanValue.create());
    }

    private static boolean hasAbsurdity(@NotNull Pattern term, @NotNull MutableBooleanValue b) {
        if (term == Pattern.Absurd.INSTANCE) {
            b.set(true);
        } else {
            term.forEach((sourcePos, p) -> b.set(b.get() || ClauseTycker.hasAbsurdity(p, b)));
        }
        return b.get();
    }

    @NotNull
    private static Jdg inlineTerm(@NotNull Jdg r) {
        return r.map((UnaryOperator)new TermInline());
    }

    @NotNull
    private static PatternTycker.TyckResult inline(@NotNull PatternTycker.TyckResult result, @NotNull LocalCtx ctx) {
        ImmutableSeq wellTyped = result.wellTyped().map(x -> x.inline((arg_0, arg_1) -> ((LocalCtx)ctx).put(arg_0, arg_1)).descentTerm((UnaryOperator)new TermInline()));
        ImmutableSeq paramSubst = result.paramSubst().map(ClauseTycker::inlineTerm);
        result.asSubst().let().replaceAll((localVar, t) -> ClauseTycker.inlineTerm(t));
        return new PatternTycker.TyckResult((ImmutableSeq<Pat>)wellTyped, (ImmutableSeq<Jdg>)paramSubst, result.asSubst(), result.hasError());
    }

    public record TyckResult(@NotNull ImmutableSeq<Pat.Preclause<Term>> clauses, boolean hasLhsError) {
        @NotNull
        public ImmutableSeq<WithPos<Term.Matching>> wellTyped() {
            return this.clauses.flatMap(Pat.Preclause::lift);
        }

        @Nullable
        public ImmutableIntSeq absurdPrefixCount() {
            int[] ints = new int[this.clauses.size()];
            int count = 0;
            for (int i = 0; i < this.clauses.size(); ++i) {
                Pat.Preclause clause = (Pat.Preclause)this.clauses.get(i);
                if (clause.expr() == null) {
                    // empty if block
                }
                ints[i] = ++count;
            }
            if (count == 0) {
                return null;
            }
            return ImmutableIntArray.Unsafe.wrap((int[])ints);
        }
    }

    private static final class TermInline
    implements UnaryOperator<Term> {
        private TermInline() {
        }

        @Override
        @NotNull
        public Term apply(@NotNull Term term) {
            if (term instanceof MetaPatTerm) {
                MetaPatTerm metaPat = (MetaPatTerm)term;
                boolean isEmpty = metaPat.meta().solution().isEmpty();
                if (isEmpty) {
                    throw new Panic("Unable to inline " + String.valueOf(metaPat.toDoc((PrettierOptions)AyaPrettierOptions.debug())));
                }
                return metaPat.inline((UnaryOperator)this);
            }
            return term.descent((UnaryOperator)this);
        }
    }

    public record LhsResult(@NotNull LocalCtx localCtx, @NotNull Term result, int unpiParamSize, @NotNull ImmutableSeq<Pat> freePats, @NotNull SourcePos sourcePos, @Nullable WithPos<Expr> body, @NotNull ImmutableSeq<Jdg> paramSubst, @NotNull LocalLet asSubst, boolean hasError) implements SourceNode
    {
        @Contract(mutates="param2")
        public void dumpLocalLetTo(@NotNull ImmutableSeq<LocalVar> teleBinds, @NotNull ExprTycker exprTycker) {
            teleBinds.forEachWith(this.paramSubst, (arg_0, arg_1) -> ((LocalLet)exprTycker.localLet()).put(arg_0, arg_1));
            exprTycker.setLocalLet(exprTycker.localLet().derive(this.asSubst.let()));
        }
    }

    private static final class TermInPatInline {
        private TermInPatInline() {
        }

        public static void apply(@NotNull Pattern pat) {
            MutableValue typeRef;
            Pattern pattern = pat;
            Objects.requireNonNull(pattern);
            Pattern pattern2 = pattern;
            int n = 0;
            switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{Pattern.Bind.class, Pattern.As.class}, (Pattern)pattern2, n)) {
                case 0: {
                    Pattern.Bind bind = (Pattern.Bind)pattern2;
                    MutableValue mutableValue = bind.type();
                    break;
                }
                case 1: {
                    Pattern.As as = (Pattern.As)pattern2;
                    MutableValue mutableValue = as.type();
                    break;
                }
                default: {
                    MutableValue mutableValue = typeRef = null;
                }
            }
            if (typeRef != null) {
                typeRef.update(it -> it == null ? null : it.descent((UnaryOperator)new TermInline()));
            }
            pat.forEach((sourcePos, p) -> TermInPatInline.apply(p));
        }
    }

    public record Worker(@NotNull ClauseTycker parent, @NotNull ImmutableSeq<Param> telescope, @NotNull DepTypeTerm.Unpi unpi, @NotNull ImmutableSeq<LocalVar> teleVars, @NotNull ImmutableSeq<LocalVar> elims, @NotNull ImmutableSeq<Pattern.Clause> clauses) {
        @NotNull
        public WorkerResult check(@NotNull SourcePos overallPos) {
            ImmutableSeq<PatClass.Seq<Term, Pat>> classes;
            MutableSeq<LhsResult> lhs = this.checkAllLhs();
            boolean hasError = lhs.anyMatch(LhsResult::hasError);
            if (!hasError) {
                classes = PatClassifier.classify((SeqView<ImmutableSeq<Pat>>)lhs.view().map(LhsResult::freePats), (SeqView<Param>)this.telescope.view().concat((SeqLike)this.unpi.params()), this.parent.exprTycker, overallPos);
                if (this.clauses.isNotEmpty()) {
                    MutableSeq<MutableList<PatClass.Seq<Term, Pat>>> usages = PatClassifier.firstMatchDomination(this.clauses, this.parent, classes);
                    for (int i = 0; i < usages.size(); ++i) {
                        LhsResult newLhs;
                        MutableList currentClasses;
                        if (((Pattern.Clause)this.clauses.get((int)i)).expr.isEmpty() || !(currentClasses = (MutableList)usages.get(i)).sizeEquals(1) || (newLhs = this.refinePattern((LhsResult)lhs.get(i), (PatClass.Seq<Term, Pat>)((PatClass.Seq)currentClasses.get(0)))) == null) continue;
                        lhs.set(i, (Object)newLhs);
                    }
                }
            } else {
                classes = null;
            }
            TyckResult rhs = this.parent.checkAllRhs(this.teleVars, (Seq<LhsResult>)lhs, hasError);
            FnClauseBody wellTyped = new FnClauseBody(rhs.wellTyped());
            if (classes != null) {
                ImmutableIntSeq absurds = rhs.absurdPrefixCount();
                wellTyped.classes = classes.map(cl -> cl.ignoreAbsurd(absurds));
            }
            return new WorkerResult(wellTyped, hasError);
        }

        @Nullable
        private LhsResult refinePattern(LhsResult curLhs, PatClass.Seq<Term, Pat> curCls) {
            LetReplacer lets = new PatBinder().apply(curLhs.freePats(), (ImmutableSeq<Term>)curCls.term());
            if (lets.let().let().allMatch((localVar, j) -> j.wellTyped() instanceof FreeTerm)) {
                return null;
            }
            LocalCtx sibling = Objects.requireNonNull((LocalCtx)curLhs.localCtx.parent()).derive();
            ImmutableSeq newPatterns = curCls.pat().map(pat -> pat.descentTerm((UnaryOperator)lets));
            newPatterns.forEach(pat -> pat.consumeBindings((arg_0, arg_1) -> ((LocalCtx)sibling).put(arg_0, arg_1)));
            curLhs.asSubst.let().replaceAll((localVar, t) -> t.map((UnaryOperator)lets));
            ImmutableSeq paramSubst = curLhs.paramSubst.map(jdg -> jdg.map((UnaryOperator)lets));
            lets.let().let().forEach((arg_0, arg_1) -> ((LocalLet)curLhs.asSubst).put(arg_0, arg_1));
            return new LhsResult(sibling, lets.apply(curLhs.result), curLhs.unpiParamSize, (ImmutableSeq<Pat>)newPatterns, curLhs.sourcePos, curLhs.body, (ImmutableSeq<Jdg>)paramSubst, curLhs.asSubst, curLhs.hasError);
        }

        @NotNull
        public MutableSeq<LhsResult> checkAllLhs() {
            return this.parent.checkAllLhs(() -> SignatureIterator.make(this.telescope, this.unpi, this.teleVars, this.elims), (SeqView<Pattern.Clause>)this.clauses.view(), this.unpi.params().size());
        }

        @NotNull
        public TyckResult checkNoClassify() {
            MutableSeq<LhsResult> lhsResults = this.checkAllLhs();
            return this.parent.checkAllRhs(this.teleVars, (Seq<LhsResult>)lhsResults, lhsResults.anyMatch(LhsResult::hasError));
        }
    }

    public record WorkerResult(FnClauseBody wellTyped, boolean hasLhsError) {
    }
}

