/*
 * Decompiled with CFR 0.152.
 */
package org.aya.tyck;

import java.lang.runtime.SwitchBootstraps;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import kala.collection.Map;
import kala.collection.SeqLike;
import kala.collection.SeqView;
import kala.collection.immutable.ImmutableMap;
import kala.collection.immutable.ImmutableSeq;
import kala.control.Either;
import kala.control.Option;
import org.aya.concrete.Expr;
import org.aya.concrete.Pattern;
import org.aya.concrete.stmt.Decl;
import org.aya.concrete.stmt.Signatured;
import org.aya.core.Matching;
import org.aya.core.def.CtorDef;
import org.aya.core.def.DataDef;
import org.aya.core.def.Def;
import org.aya.core.def.FieldDef;
import org.aya.core.def.FnDef;
import org.aya.core.def.PrimDef;
import org.aya.core.def.StructDef;
import org.aya.core.pat.Pat;
import org.aya.core.term.CallTerm;
import org.aya.core.term.FormTerm;
import org.aya.core.term.Term;
import org.aya.generic.Arg;
import org.aya.generic.Modifier;
import org.aya.generic.ParamLike;
import org.aya.ref.DefVar;
import org.aya.ref.Var;
import org.aya.tyck.ExprTycker;
import org.aya.tyck.env.LocalCtx;
import org.aya.tyck.error.PrimProblem;
import org.aya.tyck.pat.Conquer;
import org.aya.tyck.pat.PatClassifier;
import org.aya.tyck.pat.PatTycker;
import org.aya.tyck.trace.Trace;
import org.aya.util.TreeBuilder;
import org.aya.util.error.SourcePos;
import org.aya.util.reporter.Problem;
import org.aya.util.reporter.Reporter;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/**
 * Type checks top-level statements (function, data, struct and primitive declarations, as well as
 * data constructors and struct fields) and elaborates them into core {@link Def}s.
 *
 * <p>Checking happens in two phases. The {@code tyckHeader} overloads elaborate a declaration's
 * telescope and result type into a {@link Def.Signature}, which is stored back on the concrete
 * declaration. The {@code tyck} overloads then elaborate the body. If the declaration has no
 * signature yet, {@code tyck} runs the header phase first.
 *
 * @param reporter     sink for problems found while checking
 * @param traceBuilder optional trace recorder; when {@code null}, no trace is recorded
 */
public record StmtTycker(@NotNull Reporter reporter, @Nullable Trace.Builder traceBuilder) {
    /** Creates a fresh expression checker sharing this checker's reporter and trace builder. */
    @NotNull
    public ExprTycker newTycker() {
        return new ExprTycker(this.reporter, this.traceBuilder);
    }

    /** Runs {@code consumer} against the trace builder if tracing is enabled; otherwise does nothing. */
    private void tracing(@NotNull Consumer<Trace.@NotNull Builder> consumer) {
        if (this.traceBuilder != null) {
            consumer.accept(this.traceBuilder);
        }
    }

    /**
     * Runs {@code f} on {@code decl} inside a trace node for that declaration and inside a local
     * context derived from the checker's current one.
     *
     * <p>The trace node is closed and the previous local context is restored in a {@code finally}
     * block. This keeps the trace tree balanced and the context intact even when {@code f} throws.
     *
     * @param decl   the declaration being checked
     * @param tycker the expression checker whose local context is temporarily replaced
     * @param f      the checking routine to run
     * @return whatever {@code f} returns
     */
    private <S extends Signatured, D extends Def> D traced(@NotNull S decl, ExprTycker tycker, @NotNull BiFunction<S, ExprTycker, D> f) {
        tracing(builder -> builder.shift(new Trace.DeclT(decl.ref(), decl.sourcePos)));
        LocalCtx parent = tycker.localCtx;
        tycker.localCtx = parent.deriveMap();
        try {
            return f.apply(decl, tycker);
        } finally {
            tracing(TreeBuilder::reduce);
            tycker.localCtx = parent;
        }
    }

    /** Checks a top-level declaration (header if needed, then body) and returns its core definition. */
    @NotNull
    public Def tyck(@NotNull Decl decl, @NotNull ExprTycker tycker) {
        return traced(decl, tycker, this::doTyck);
    }

    /**
     * Elaborates the body of {@code predecl}. If no signature has been recorded for it yet, its
     * header is checked first.
     *
     * <p>Primitive declarations are not elaborated here: the core definition already attached to
     * their reference is returned as-is.
     */
    @NotNull
    private Def doTyck(@NotNull Decl predecl, @NotNull ExprTycker tycker) {
        if (predecl.signature == null) {
            tyckHeader(predecl, tycker);
        }
        final Def.Signature signature = predecl.signature;
        return switch (predecl) {
            case Decl.FnDecl decl -> {
                assert signature != null;
                BiFunction<Term, Either<Term, ImmutableSeq<Matching>>, FnDef> factory = FnDef.factory(
                    (resultTy, body) -> new FnDef(decl.ref, signature.param(), resultTy, decl.modifiers, body));
                yield decl.body.fold(expr -> {
                    // Expression body: check it against the declared result type, then zonk both.
                    Term nobody = tycker.inherit(expr, signature.result()).wellTyped();
                    tycker.solveMetas();
                    Term resultTy = tycker.zonk(signature.result());
                    return factory.apply(resultTy, Either.left(tycker.zonk(nobody)));
                }, clauses -> {
                    // Pattern-matching body.
                    PatTycker patTycker = new PatTycker(tycker);
                    SourcePos pos = decl.sourcePos;
                    FnDef def;
                    if (decl.modifiers.contains(Modifier.Overlap)) {
                        // Overlapping clauses are elaborated as written and then handed to
                        // ensureConfluent (with coverage checking enabled).
                        PatTycker.PatResult result = patTycker.elabClausesDirectly(clauses, signature);
                        def = factory.apply(result.result(), Either.right(result.matchings()));
                        if (patTycker.noError()) {
                            ensureConfluent(tycker, signature, result, pos, true);
                        }
                    } else {
                        PatTycker.PatResult result = patTycker.elabClausesClassified(clauses, signature, decl.result.sourcePos(), pos);
                        def = factory.apply(result.result(), Either.right(result.matchings()));
                        if (patTycker.noError()) {
                            Conquer.against(result.matchings(), true, tycker, pos, signature);
                        }
                    }
                    return def;
                });
            }
            case Decl.DataDecl decl -> {
                assert signature != null;
                // Each constructor is checked in its own trace node and derived local context.
                ImmutableSeq<CtorDef> body = decl.body.map(ctor -> traced(ctor, tycker, this::tyck));
                yield new DataDef(decl.ref, signature.param(), decl.ulift, body);
            }
            case Decl.PrimDecl decl -> decl.ref.core;
            case Decl.StructDecl decl -> {
                assert signature != null;
                // Each field is checked in its own trace node and derived local context.
                ImmutableSeq<FieldDef> body = decl.fields.map(field -> traced(field, tycker, this::tyck));
                yield new StructDef(decl.ref, signature.param(), decl.ulift, body);
            }
        };
    }

    /** Checks a simple (expression-bodied) function declaration inside its own trace node. */
    @NotNull
    public FnDef simpleFn(@NotNull ExprTycker tycker, Decl.FnDecl fn) {
        return traced(fn, tycker, (decl, t) -> doSimpleFn(tycker, fn));
    }

    /**
     * Checks the telescope, result type and expression body of {@code fn} in one pass. It records
     * the zonked signature on {@code fn} and returns the elaborated function.
     *
     * <p>The body is read with {@code getLeftValue()}, so {@code fn} must have an expression body.
     */
    @NotNull
    private FnDef doSimpleFn(@NotNull ExprTycker tycker, Decl.FnDecl fn) {
        ImmutableSeq<TeleResult> okTele = checkTele(tycker, fn.telescope, -1);
        Term preresult = tycker.synthesize(fn.result).wellTyped();
        Expr bodyExpr = fn.body.getLeftValue();
        Term prebody = tycker.inherit(bodyExpr, preresult).wellTyped();
        tycker.solveMetas();
        Term result = tycker.zonk(preresult);
        ImmutableSeq<Term.Param> tele = zonkTele(tycker, okTele);
        fn.signature = new Def.Signature(tele, result);
        Term body = tycker.zonk(prebody);
        return new FnDef(fn.ref, tele, result, fn.modifiers, Either.left(body));
    }

    /**
     * Checks the header (telescope and result type) of a top-level declaration and stores the
     * resulting signature on it. Struct and data declarations also get their universe lift
     * ({@code ulift}) recorded.
     *
     * <p>The "telescope" trace node is always closed, including on the early exit taken for a
     * primitive declared without a result type, and when checking throws.
     */
    public void tyckHeader(@NotNull Decl decl, @NotNull ExprTycker tycker) {
        tracing(builder -> builder.shift(new Trace.LabelT(decl.sourcePos, "telescope")));
        try {
            switch (decl) {
                case Decl.FnDecl fn -> {
                    ImmutableSeq<Term.Param> tele = tele(tycker, fn.telescope, -1);
                    // NOTE(review): freezeHoles presumably keeps unsolved holes in the result type — confirm.
                    Term result = tycker.synthesize(fn.result).wellTyped().freezeHoles(tycker.state);
                    fn.signature = new Def.Signature(tele, result);
                }
                case Decl.DataDecl data -> {
                    ImmutableSeq<Term.Param> tele = tele(tycker, data.telescope, -1);
                    // An omitted result type (a hole) falls back to FormTerm.Univ.ZERO.
                    Term result = data.result instanceof Expr.HoleExpr
                        ? FormTerm.Univ.ZERO
                        : tycker.zonk(tycker.synthesize(data.result)).wellTyped();
                    data.signature = new Def.Signature(tele, result);
                    data.ulift = tycker.ensureUniv(data.result, result);
                }
                case Decl.StructDecl struct -> {
                    ImmutableSeq<Term.Param> tele = tele(tycker, struct.telescope, -1);
                    Term result = tycker.zonk(tycker.synthesize(struct.result)).wellTyped();
                    struct.signature = new Def.Signature(tele, result);
                    struct.ulift = tycker.ensureUniv(struct.result, result);
                }
                case Decl.PrimDecl prim -> tyckPrimHeader(prim, tycker);
            }
        } finally {
            tracing(TreeBuilder::reduce);
        }
    }

    /**
     * Checks a primitive's declared header against the core definition already attached to
     * {@code prim.ref.core}, and records the signature on {@code prim}.
     *
     * <p>If the primitive has parameters but its result is an {@code ErrorExpr} (apparently an
     * omitted result type, given the error reported), a {@code NoResultTypeError} is reported.
     * The method then returns without setting a signature or solving metas.
     */
    private void tyckPrimHeader(@NotNull Decl.PrimDecl prim, @NotNull ExprTycker tycker) {
        assert tycker.localCtx.isEmpty();
        PrimDef core = prim.ref.core;
        ImmutableSeq<Term.Param> tele = tele(tycker, prim.telescope, -1);
        if (tele.isNotEmpty()) {
            if (prim.result instanceof Expr.ErrorExpr) {
                reporter.report(new PrimProblem.NoResultTypeError(prim));
                return;
            }
            Term result = tycker.synthesize(prim.result).wellTyped();
            // Compare the whole declared type with the built-in one.
            tycker.unifyTyReported(FormTerm.Pi.make(tele, result), FormTerm.Pi.make(core.telescope, core.result), prim.result);
            prim.signature = new Def.Signature(tele, result);
        } else if (!(prim.result instanceof Expr.ErrorExpr)) {
            Term result = tycker.synthesize(prim.result).wellTyped();
            tycker.unifyTyReported(result, core.result, prim.result);
            // Record the signature here as well. Otherwise doTyck would see a null signature and
            // check this header a second time.
            prim.signature = new Def.Signature(tele, result);
        } else {
            // Neither telescope nor result was written: take the built-in signature verbatim.
            prim.signature = new Def.Signature(core.telescope, core.result);
        }
        tycker.solveMetas();
    }

    /**
     * Checks the header of a data constructor. This covers the patterns on the data type's
     * parameters (if any) and the constructor's own telescope. The partially checked state is
     * stored on {@code ctor} ({@code yetTycker}, {@code yetTyckedPat}, {@code patternTele}) for
     * {@link #tyck(Decl.DataCtor, ExprTycker)} to pick up. Does nothing if a signature already exists.
     */
    public void tyckHeader(@NotNull Decl.DataCtor ctor, ExprTycker tycker) {
        if (ctor.signature != null) {
            return;
        }
        DefVar<DataDef, Decl.DataDecl> dataRef = ctor.dataRef;
        Decl.DataDecl dataConcrete = dataRef.concrete;
        Def.Signature dataSig = dataConcrete.signature;
        assert dataSig != null;
        // The constructor's result is the data type applied to its own parameters.
        ImmutableSeq<Arg<Term>> dataArgs = dataSig.param().map(Term.Param::toArg);
        CallTerm.Data dataCall = new CallTerm.Data(dataRef, 0, dataArgs);
        Def.Signature sig = new Def.Signature(dataSig.param(), dataCall);
        PatTycker patTycker = new PatTycker(tycker);
        ImmutableSeq<Pat> pat = ctor.patterns.isNotEmpty()
            ? patTycker.visitPatterns(sig, ctor.patterns.view())._1
            : ImmutableSeq.empty();
        // Constructor parameters live in the same universe level as the data type.
        ImmutableSeq<Term.Param> tele = tele(tycker, ctor.telescope, dataConcrete.ulift);
        ctor.signature = new Def.Signature(tele, dataCall);
        ctor.yetTycker = patTycker;
        ctor.yetTyckedPat = pat;
        ctor.patternTele = pat.isEmpty() ? dataSig.param().map(Term.Param::implicitify) : Pat.extractTele(pat);
    }

    /**
     * Elaborates a data constructor. The result is appended to the owning data declaration's
     * {@code checkedBody}. If the constructor's core definition already exists it is returned
     * unchanged.
     */
    @NotNull
    public CtorDef tyck(@NotNull Decl.DataCtor ctor, ExprTycker tycker) {
        if (ctor.ref.core != null) {
            return ctor.ref.core;
        }
        DefVar<DataDef, Decl.DataDecl> dataRef = ctor.dataRef;
        Decl.DataDecl dataConcrete = dataRef.concrete;
        Def.Signature dataSig = dataConcrete.signature;
        assert dataSig != null;
        if (ctor.signature == null) {
            tyckHeader(ctor, tycker);
        }
        Def.Signature signature = ctor.signature;
        CallTerm.Data dataCall = (CallTerm.Data) signature.result();
        ImmutableSeq<Term.Param> tele = signature.param();
        PatTycker patTycker = ctor.yetTycker;
        ImmutableSeq<Pat> pat = ctor.yetTyckedPat;
        assert patTycker != null && pat != null;
        // The pattern checker from the header phase must share this expression checker.
        assert tycker == patTycker.exprTycker;
        if (pat.isNotEmpty()) {
            // Substitute the constructor's patterns for the data type's parameters in the result.
            dataCall = (CallTerm.Data) dataCall.subst(ImmutableMap.from(
                dataSig.param().view().map(Term.Param::ref).zip(pat.view().map(Pat::toTerm))));
        }
        PatTycker.PatResult elabClauses = patTycker.elabClausesDirectly(ctor.clauses, signature);
        CtorDef elaborated = new CtorDef(dataRef, ctor.ref, pat, ctor.patternTele, tele, elabClauses.matchings(), dataCall, ctor.coerce);
        dataConcrete.checkedBody.append(elaborated);
        if (patTycker.noError()) {
            ensureConfluent(tycker, signature, elabClauses, ctor.sourcePos, false);
        }
        return elaborated;
    }

    /**
     * Classifies the elaborated clauses and runs the confluence check and {@code Conquer.against}
     * on them, then solves metas. The work runs inside a "confluence check" trace node, which is
     * always closed.
     *
     * <p>Skipped entirely when {@code coverage} is false and there are no matchings.
     */
    private void ensureConfluent(ExprTycker tycker, Def.Signature signature, PatTycker.PatResult elabClauses, SourcePos pos, boolean coverage) {
        if (!coverage && elabClauses.matchings().isEmpty()) {
            return;
        }
        tracing(builder -> builder.shift(new Trace.LabelT(pos, "confluence check")));
        try {
            var classification = PatClassifier.classify(elabClauses.clauses(), signature.param(), tycker, pos, coverage);
            PatClassifier.confluence(elabClauses, tycker, pos, classification);
            Conquer.against(elabClauses.matchings(), true, tycker, pos, signature);
            tycker.solveMetas();
        } finally {
            tracing(TreeBuilder::reduce);
        }
    }

    /**
     * Checks the header of a struct field. The telescope and the result type are checked against
     * the struct's universe level. Does nothing if a signature already exists.
     */
    public void tyckHeader(@NotNull Decl.StructField field, ExprTycker tycker) {
        if (field.signature != null) {
            return;
        }
        DefVar<StructDef, Decl.StructDecl> structRef = field.structRef;
        Decl.StructDecl structConcrete = structRef.concrete;
        assert structConcrete.signature != null;
        int structLvl = structConcrete.ulift;
        ImmutableSeq<Term.Param> tele = tele(tycker, field.telescope, structLvl);
        Term result = tycker.zonk(tycker.inherit(field.result, new FormTerm.Univ(structLvl))).wellTyped();
        field.signature = new Def.Signature(tele, result);
    }

    /**
     * Elaborates a struct field, including its clauses and optional default body. If the field's
     * core definition already exists it is returned unchanged.
     */
    @NotNull
    public FieldDef tyck(@NotNull Decl.StructField field, ExprTycker tycker) {
        if (field.ref.core != null) {
            return field.ref.core;
        }
        DefVar<StructDef, Decl.StructDecl> structRef = field.structRef;
        Def.Signature structSig = structRef.concrete.signature;
        assert structSig != null;
        if (field.signature == null) {
            tyckHeader(field, tycker);
        }
        Def.Signature signature = field.signature;
        ImmutableSeq<Term.Param> tele = signature.param();
        Term result = signature.result();
        PatTycker patTycker = new PatTycker(tycker);
        PatTycker.PatResult clauses = patTycker.elabClausesDirectly(field.clauses, signature);
        // NOTE(review): unlike the function-body path in doTyck, this body is neither meta-solved
        // nor zonked before being stored — confirm that is intended.
        Option<Term> body = field.body.map(e -> tycker.inherit(e, result).wellTyped());
        FieldDef elaborated = new FieldDef(structRef, field.ref, structSig.param(), tele, result, clauses.matchings(), body, field.coerce);
        if (patTycker.noError()) {
            ensureConfluent(tycker, signature, clauses, field.sourcePos, false);
        }
        return elaborated;
    }

    /**
     * Checks a telescope, solves metas, and returns the zonked parameters.
     *
     * @param sort universe level to check parameter types against; negative means the parameter
     *             types are synthesized instead
     */
    @NotNull
    private ImmutableSeq<Term.Param> tele(@NotNull ExprTycker tycker, @NotNull ImmutableSeq<Expr.Param> tele, int sort) {
        ImmutableSeq<TeleResult> okTele = checkTele(tycker, tele, sort);
        tycker.solveMetas();
        return zonkTele(tycker, okTele);
    }

    /**
     * Checks each parameter type in order and adds the parameter to the local context, so later
     * parameters can refer to earlier ones.
     *
     * @param sort universe level to check against when {@code >= 0}; otherwise types are synthesized
     */
    @NotNull
    private ImmutableSeq<TeleResult> checkTele(@NotNull ExprTycker exprTycker, @NotNull ImmutableSeq<Expr.Param> tele, int sort) {
        return tele.map(param -> {
            Term paramTyped = (sort >= 0
                ? exprTycker.inherit(param.type(), new FormTerm.Univ(sort))
                : exprTycker.synthesize(param.type())).wellTyped();
            Term.Param newParam = new Term.Param(param, paramTyped);
            exprTycker.localCtx.put(newParam);
            return new TeleResult(newParam, param.sourcePos());
        });
    }

    /**
     * Zonks each checked parameter's type and puts the zonked parameter into the local context.
     * This presumably replaces the unzonked entry for the same variable (TODO confirm).
     */
    @NotNull
    private ImmutableSeq<Term.Param> zonkTele(@NotNull ExprTycker exprTycker, ImmutableSeq<TeleResult> okTele) {
        return okTele.map(tt -> {
            Term.Param rawParam = tt.param();
            Term.Param param = new Term.Param(rawParam, exprTycker.zonk(rawParam.type()));
            exprTycker.localCtx.put(param);
            return param;
        });
    }

    /** A checked but not yet zonked telescope parameter, with the source position of its declaration. */
    private record TeleResult(@NotNull Term.Param param, @NotNull SourcePos pos) {
    }
}

