/*
 * Decompiled with CFR 0.152.
 */
package org.aya.tyck;

import java.lang.runtime.SwitchBootstraps;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import kala.collection.Map;
import kala.collection.SeqLike;
import kala.collection.SeqView;
import kala.collection.immutable.ImmutableMap;
import kala.collection.immutable.ImmutableSeq;
import kala.control.Either;
import kala.control.Option;
import org.aya.concrete.Expr;
import org.aya.concrete.Pattern;
import org.aya.concrete.stmt.ClassDecl;
import org.aya.concrete.stmt.Decl;
import org.aya.concrete.stmt.TeleDecl;
import org.aya.core.Matching;
import org.aya.core.def.CtorDef;
import org.aya.core.def.DataDef;
import org.aya.core.def.Def;
import org.aya.core.def.FieldDef;
import org.aya.core.def.FnDef;
import org.aya.core.def.GenericDef;
import org.aya.core.def.PrimDef;
import org.aya.core.def.StructDef;
import org.aya.core.pat.Pat;
import org.aya.core.repr.AyaShape;
import org.aya.core.term.CallTerm;
import org.aya.core.term.FormTerm;
import org.aya.core.term.Term;
import org.aya.generic.Arg;
import org.aya.generic.Modifier;
import org.aya.generic.ParamLike;
import org.aya.ref.AnyVar;
import org.aya.ref.DefVar;
import org.aya.tyck.ExprTycker;
import org.aya.tyck.env.LocalCtx;
import org.aya.tyck.error.NobodyError;
import org.aya.tyck.error.PrimError;
import org.aya.tyck.pat.Conquer;
import org.aya.tyck.pat.PatClassifier;
import org.aya.tyck.pat.PatTycker;
import org.aya.tyck.trace.Trace;
import org.aya.tyck.unify.Unifier;
import org.aya.util.Ordering;
import org.aya.util.TreeBuilder;
import org.aya.util.error.SourcePos;
import org.aya.util.reporter.Problem;
import org.aya.util.reporter.Reporter;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/**
 * Type checker for declarations: elaborates concrete {@link Decl}s into core {@link GenericDef}s.
 * <p>
 * Checking happens in two phases. {@link #tyckHeader} elaborates a declaration's telescope and
 * result type into a {@link Def.Signature} and stores it on the concrete declaration;
 * {@link #tyck} elaborates the body, running the header phase first when the declaration has no
 * signature yet.
 *
 * @param reporter     sink for the problems this class reports directly
 * @param traceBuilder optional trace sink; {@code null} disables tracing
 */
public record StmtTycker(@NotNull Reporter reporter, @Nullable Trace.Builder traceBuilder) {
    /**
     * Creates an {@link ExprTycker} that shares this tycker's reporter and trace builder.
     *
     * @param primFactory   factory handed to the new expression tycker
     * @param literalShapes shape factory handed to the new expression tycker
     * @return a fresh expression tycker
     */
    @NotNull
    public ExprTycker newTycker(@NotNull PrimDef.Factory primFactory, @NotNull AyaShape.Factory literalShapes) {
        return new ExprTycker(primFactory, literalShapes, reporter, traceBuilder);
    }

    /** Runs {@code consumer} against the trace builder; does nothing when tracing is disabled. */
    private void tracing(@NotNull Consumer<Trace.Builder> consumer) {
        if (traceBuilder != null) {
            consumer.accept(traceBuilder);
        }
    }

    /**
     * Runs {@code f} on {@code yeah} inside a {@link Trace.DeclT} trace node and with a fresh local
     * context derived from {@code p}'s current one. The trace node is closed and the previous local
     * context is restored even if {@code f} throws, so callers never observe a leaked context.
     */
    private <S extends Decl, D extends GenericDef> D traced(@NotNull S yeah, ExprTycker p, @NotNull BiFunction<S, ExprTycker, D> f) {
        tracing(builder -> builder.shift(new Trace.DeclT(yeah.ref(), yeah.sourcePos())));
        var parent = p.localCtx;
        p.localCtx = parent.deriveMap();
        try {
            return f.apply(yeah, p);
        } finally {
            tracing(TreeBuilder::reduce);
            p.localCtx = parent;
        }
    }

    /**
     * Elaborates {@code decl} into its core definition, checking its header first if needed.
     * The work runs in a derived local context and, when tracing is enabled, under its own trace node.
     *
     * @param decl   the declaration to check
     * @param tycker the expression tycker used for all terms inside the declaration
     * @return the elaborated core definition
     * @throws UnsupportedOperationException if {@code decl} is a {@link ClassDecl}
     */
    @NotNull
    public GenericDef tyck(@NotNull Decl decl, @NotNull ExprTycker tycker) {
        return traced(decl, tycker, this::doTyck);
    }

    /** Body of {@link #tyck}; dispatches on the kind of declaration. */
    @NotNull
    private GenericDef doTyck(@NotNull Decl predecl, @NotNull ExprTycker tycker) {
        // A telescopic declaration without a signature has not gone through the header phase yet.
        if (predecl instanceof Decl.Telescopic unchecked && unchecked.signature() == null) {
            tyckHeader(predecl, tycker);
        }
        final Def.Signature signature = predecl instanceof Decl.Telescopic checked ? checked.signature() : null;
        return switch (predecl) {
            case ClassDecl classDecl -> throw new UnsupportedOperationException("ClassDecl is not supported yet");
            case TeleDecl.FnDecl decl -> {
                assert signature != null;
                var factory = FnDef.factory((resultTy, body) ->
                    new FnDef(decl.ref, signature.param(), resultTy, decl.modifiers, body));
                yield decl.body.fold(body -> {
                    // Expression body: check against the declared result, then zonk both.
                    var nobody = tycker.inherit(body, signature.result()).wellTyped();
                    tycker.solveMetas();
                    var resultTy = tycker.zonk(signature.result());
                    return factory.apply(resultTy, Either.left(tycker.zonk(nobody)));
                }, clauses -> {
                    // Pattern-matching body.
                    FnDef def;
                    var patTycker = new PatTycker(tycker);
                    var pos = decl.sourcePos;
                    if (decl.modifiers.contains(Modifier.Overlap)) {
                        // Overlapping clauses: elaborate as written, then check confluence (with coverage).
                        var result = patTycker.elabClausesDirectly(clauses, signature);
                        def = factory.apply(result.result(), Either.right(result.matchings()));
                        if (patTycker.noError()) {
                            ensureConfluent(tycker, signature, result, pos, true);
                        }
                    } else {
                        // Ordinary clauses: elaborate with classification, then run Conquer on the matchings.
                        var result = patTycker.elabClausesClassified(clauses, signature, pos);
                        def = factory.apply(result.result(), Either.right(result.matchings()));
                        if (patTycker.noError()) {
                            Conquer.against(result.matchings(), true, tycker, pos, signature);
                        }
                    }
                    return def;
                });
            }
            case TeleDecl.DataDecl decl -> {
                assert signature != null;
                var body = decl.body.map(clause -> (CtorDef) tyck(clause, tycker));
                yield new DataDef(decl.ref, signature.param(), decl.ulift, body);
            }
            // The core definition of a primitive is already attached to its ref.
            case TeleDecl.PrimDecl decl -> decl.ref.core;
            case TeleDecl.StructDecl decl -> {
                assert signature != null;
                var body = decl.fields.map(field -> (FieldDef) tyck(field, tycker));
                yield new StructDef(decl.ref, signature.param(), decl.ulift, body);
            }
            case TeleDecl.DataCtor ctor -> {
                // Reuse an already attached core definition instead of elaborating twice.
                if (ctor.ref.core != null) {
                    yield ctor.ref.core;
                }
                assert signature == ctor.signature && signature != null;
                DefVar<DataDef, TeleDecl.DataDecl> dataRef = ctor.dataRef;
                var dataConcrete = dataRef.concrete;
                var dataSig = dataConcrete.signature;
                assert dataSig != null;
                var dataCall = (CallTerm.Data) signature.result();
                var tele = signature.param();
                // Both were stashed on the constructor by the header phase.
                var patTycker = ctor.yetTycker;
                var pat = ctor.yetTyckedPat;
                assert patTycker != null && pat != null;
                assert tycker == patTycker.exprTycker;
                if (pat.isNotEmpty()) {
                    // Substitute the data type's parameters with the constructor's patterns (as terms).
                    dataCall = (CallTerm.Data) dataCall.subst(ImmutableMap.from(
                        dataSig.param().view().map(Term.Param::ref).zip(pat.view().map(Pat::toTerm))));
                }
                var elabClauses = patTycker.elabClausesDirectly(ctor.clauses, signature);
                var elaborated = new CtorDef(dataRef, ctor.ref, pat, ctor.patternTele, tele,
                    elabClauses.matchings(), dataCall, ctor.coerce);
                dataConcrete.checkedBody.append(elaborated);
                if (patTycker.noError()) {
                    ensureConfluent(tycker, signature, elabClauses, ctor.sourcePos, false);
                }
                yield elaborated;
            }
            case TeleDecl.StructField field -> {
                // Reuse an already attached core definition instead of elaborating twice.
                if (field.ref.core != null) {
                    yield field.ref.core;
                }
                assert signature == field.signature && signature != null;
                DefVar<StructDef, TeleDecl.StructDecl> structRef = field.structRef;
                var structSig = structRef.concrete.signature;
                assert structSig != null;
                var result = signature.result();
                var body = field.body.map(e -> tycker.inherit(e, result).wellTyped());
                yield new FieldDef(structRef, field.ref, structSig.param(), signature.param(), result, body, field.coerce);
            }
        };
    }

    /**
     * Checks a function whose body is a single expression, doing header and body in one pass:
     * telescope, result type and body are checked, metas are solved, and everything is zonked.
     * The resulting signature is stored on {@code fn}.
     * NOTE(review): reads {@code fn.body.getLeftValue()}, so callers must only pass expression-bodied functions.
     */
    @NotNull
    public FnDef simpleFn(@NotNull ExprTycker tycker, TeleDecl.FnDecl fn) {
        return traced(fn, tycker, (decl, exprTycker) -> doSimpleFn(exprTycker, decl));
    }

    /** Body of {@link #simpleFn}. */
    @NotNull
    private FnDef doSimpleFn(@NotNull ExprTycker tycker, TeleDecl.FnDecl fn) {
        var okTele = checkTele(tycker, fn.telescope, null);
        var preresult = tycker.synthesize(fn.result).wellTyped();
        Expr bodyExpr = fn.body.getLeftValue();
        var prebody = tycker.inherit(bodyExpr, preresult).wellTyped();
        tycker.solveMetas();
        var result = tycker.zonk(preresult);
        var tele = zonkTele(tycker, okTele);
        fn.signature = new Def.Signature(tele, result);
        var body = tycker.zonk(prebody);
        return new FnDef(fn.ref, tele, result, fn.modifiers, Either.left(body));
    }

    /**
     * Header phase: elaborates the telescope and result type of {@code decl} and stores the
     * resulting {@link Def.Signature} (plus kind-specific header data) on the concrete declaration.
     * The "telescope of ..." trace node opened here is always closed, including on early returns.
     *
     * @throws UnsupportedOperationException if {@code decl} is a {@link ClassDecl}
     */
    public void tyckHeader(@NotNull Decl decl, @NotNull ExprTycker tycker) {
        tracing(builder -> builder.shift(new Trace.LabelT(decl.sourcePos(), "telescope of " + decl.ref().name())));
        try {
            switch (decl) {
                case ClassDecl classDecl -> throw new UnsupportedOperationException("ClassDecl is not supported yet");
                case TeleDecl.FnDecl fn -> {
                    var resultTele = tele(tycker, fn.telescope, null);
                    var resultRes = tycker.synthesize(fn.result).wellTyped().freezeHoles(tycker.state);
                    fn.signature = new Def.Signature(resultTele, resultRes);
                    // No parameters and an empty clause list: report NobodyError.
                    if (resultTele.isEmpty() && fn.body.isRight() && fn.body.getRightValue().isEmpty()) {
                        reporter.report(new NobodyError(decl.sourcePos(), fn.ref));
                    }
                }
                case TeleDecl.DataDecl data -> {
                    var tele = tele(tycker, data.telescope, null);
                    FormTerm.Sort resultTy = resultTy(tycker, data);
                    data.ulift = resultTy;
                    data.signature = new Def.Signature(tele, resultTy);
                }
                case TeleDecl.StructDecl struct -> {
                    var tele = tele(tycker, struct.telescope, null);
                    FormTerm.Sort result = resultTy(tycker, struct);
                    struct.signature = new Def.Signature(tele, result);
                    struct.ulift = result;
                }
                case TeleDecl.PrimDecl prim -> {
                    assert tycker.localCtx.isEmpty();
                    var core = prim.ref.core;
                    var tele = tele(tycker, prim.telescope, null);
                    if (tele.isNotEmpty()) {
                        // A primitive with parameters needs a result type; an ErrorExpr result is reported.
                        if (prim.result instanceof Expr.ErrorExpr) {
                            reporter.report(new PrimError.NoResultType(prim));
                            return;
                        }
                        var result = tycker.synthesize(prim.result).wellTyped();
                        tycker.unifyTyReported(FormTerm.Pi.make(tele, result), FormTerm.Pi.make(core.telescope, core.result), prim.result);
                        prim.signature = new Def.Signature(tele, result);
                    } else if (!(prim.result instanceof Expr.ErrorExpr)) {
                        var result = tycker.synthesize(prim.result).wellTyped();
                        tycker.unifyTyReported(result, core.result, prim.result);
                        // Record the signature here as well; otherwise doTyck sees a null signature
                        // and runs this whole header check a second time.
                        prim.signature = new Def.Signature(tele, result);
                    } else {
                        // No telescope and no usable result: fall back to the core primitive's signature.
                        prim.signature = new Def.Signature(core.telescope, core.result);
                    }
                    tycker.solveMetas();
                }
                case TeleDecl.DataCtor ctor -> {
                    // The header was already checked.
                    if (ctor.signature != null) {
                        return;
                    }
                    DefVar<DataDef, TeleDecl.DataDecl> dataRef = ctor.dataRef;
                    var dataConcrete = dataRef.concrete;
                    var dataSig = dataConcrete.signature;
                    assert dataSig != null;
                    var dataArgs = dataSig.param().map(Term.Param::toArg);
                    var dataCall = new CallTerm.Data(dataRef, 0, dataArgs);
                    var sig = new Def.Signature(dataSig.param(), dataCall);
                    var patTycker = new PatTycker(tycker);
                    // Patterns on the constructor are checked against the data type's own signature.
                    ImmutableSeq<Pat> pat = ctor.patterns.isNotEmpty()
                        ? patTycker.visitPatterns(sig, ctor.patterns.view())._1.toImmutableSeq()
                        : ImmutableSeq.empty();
                    var tele = tele(tycker, ctor.telescope, dataConcrete.ulift);
                    ctor.signature = new Def.Signature(tele, dataCall);
                    // Stashed for the body phase in doTyck.
                    ctor.yetTycker = patTycker;
                    ctor.yetTyckedPat = pat;
                    ctor.patternTele = pat.isEmpty()
                        ? dataSig.param().map(Term.Param::implicitify)
                        : Pat.extractTele(pat);
                }
                case TeleDecl.StructField field -> {
                    // The header was already checked.
                    if (field.signature != null) {
                        return;
                    }
                    DefVar<StructDef, TeleDecl.StructDecl> structRef = field.structRef;
                    var structConcrete = structRef.concrete;
                    var structSig = structConcrete.signature;
                    assert structSig != null;
                    FormTerm.Sort structLvl = structConcrete.ulift;
                    var tele = tele(tycker, field.telescope, structLvl);
                    var result = tycker.zonk(tycker.inherit(field.result, structLvl)).wellTyped();
                    field.signature = new Def.Signature(tele, result);
                }
            }
        } finally {
            tracing(TreeBuilder::reduce);
        }
    }

    /**
     * Elaborates the universe a data/struct declaration lives in.
     * A result written as a {@link Expr.HoleExpr} yields {@code FormTerm.Type.ZERO}.
     */
    private FormTerm.Sort resultTy(@NotNull ExprTycker tycker, TeleDecl data) {
        if (data.result instanceof Expr.HoleExpr) {
            return FormTerm.Type.ZERO;
        }
        ExprTycker.TyResult result = tycker.ty(data.result);
        return (FormTerm.Sort) tycker.zonk(result.wellTyped());
    }

    /**
     * Runs the confluence check (classification + {@link PatClassifier#confluence}) followed by
     * {@link Conquer}, under a "confluence check" trace node that is always closed.
     * When {@code coverage} is false and there are no matchings, nothing is done.
     */
    private void ensureConfluent(ExprTycker tycker, Def.Signature signature, PatTycker.PatResult elabClauses, SourcePos pos, boolean coverage) {
        if (!coverage && elabClauses.matchings().isEmpty()) {
            return;
        }
        tracing(builder -> builder.shift(new Trace.LabelT(pos, "confluence check")));
        try {
            var classification = PatClassifier.classify(elabClauses.clauses(), signature.param(), tycker, pos, coverage);
            PatClassifier.confluence(elabClauses, tycker, pos, classification);
            Conquer.against(elabClauses.matchings(), true, tycker, pos, signature);
            tycker.solveMetas();
        } finally {
            tracing(TreeBuilder::reduce);
        }
    }

    /**
     * Checks a telescope, solves metas, then zonks it.
     *
     * @param sort when non-null, every parameter type is checked to be a sort compared against it
     *             (see {@link #checkTele(ExprTycker, Expr, FormTerm.Sort)}); otherwise types are synthesized
     */
    @NotNull
    private ImmutableSeq<Term.Param> tele(@NotNull ExprTycker tycker, @NotNull ImmutableSeq<Expr.Param> tele, @Nullable FormTerm.Sort sort) {
        var okTele = checkTele(tycker, tele, sort);
        tycker.solveMetas();
        return zonkTele(tycker, okTele);
    }

    /**
     * Checks that {@code tele} is a sort and compares its universe against {@code sort} with
     * {@link Ordering#Lt}. A {@code Prop} parameter is not compared against a {@code Type} bound,
     * and an {@code ISet} parameter is not compared against a {@code Type} or {@code Set} bound.
     */
    @NotNull
    private ExprTycker.Result checkTele(@NotNull ExprTycker exprTycker, @NotNull Expr tele, @NotNull FormTerm.Sort sort) {
        ExprTycker.SortResult result = exprTycker.sort(tele);
        Unifier unifier = exprTycker.unifier(tele.sourcePos(), Ordering.Lt);
        switch (result.type()) {
            case FormTerm.Type ty -> unifier.compareSort(ty, sort);
            case FormTerm.Set ty -> unifier.compareSort(ty, sort);
            case FormTerm.Prop ty -> {
                if (!(sort instanceof FormTerm.Type)) {
                    unifier.compareSort(ty, sort);
                }
            }
            case FormTerm.ISet ty -> {
                if (!(sort instanceof FormTerm.Type || sort instanceof FormTerm.Set)) {
                    unifier.compareSort(ty, sort);
                }
            }
        }
        return result;
    }

    /**
     * Checks each parameter type in order, adding every checked parameter to the local context
     * so later parameters can refer to earlier ones.
     */
    @NotNull
    private ImmutableSeq<TeleResult> checkTele(@NotNull ExprTycker exprTycker, @NotNull ImmutableSeq<Expr.Param> tele, @Nullable FormTerm.Sort sort) {
        return tele.map(param -> {
            var paramTyped = (sort != null
                ? checkTele(exprTycker, param.type(), sort)
                : exprTycker.synthesize(param.type())).wellTyped();
            var newParam = new Term.Param(param, paramTyped);
            exprTycker.localCtx.put(newParam);
            return new TeleResult(newParam, param.sourcePos());
        });
    }

    /** Zonks every parameter type and re-registers the zonked parameter in the local context. */
    @NotNull
    private ImmutableSeq<Term.Param> zonkTele(@NotNull ExprTycker exprTycker, ImmutableSeq<TeleResult> okTele) {
        return okTele.map(tt -> {
            var rawParam = tt.param;
            var param = new Term.Param(rawParam, exprTycker.zonk(rawParam.type()));
            exprTycker.localCtx.put(param);
            return param;
        });
    }

    /** A checked (not yet zonked) parameter together with the source position of its concrete form. */
    private record TeleResult(@NotNull Term.Param param, @NotNull SourcePos pos) {
    }
}

