/*
 * Decompiled with CFR 0.152.
 */
package org.aya.tyck;

import java.lang.runtime.SwitchBootstraps;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import kala.collection.Map;
import kala.collection.SeqLike;
import kala.collection.SeqView;
import kala.collection.immutable.ImmutableMap;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.mutable.MutableMap;
import kala.control.Either;
import kala.control.Option;
import kala.tuple.Tuple;
import kala.tuple.Tuple2;
import kala.tuple.Unit;
import org.aya.api.error.Reporter;
import org.aya.api.error.SourcePos;
import org.aya.api.ref.DefVar;
import org.aya.api.ref.Var;
import org.aya.api.util.Arg;
import org.aya.concrete.Expr;
import org.aya.concrete.Pattern;
import org.aya.concrete.stmt.Decl;
import org.aya.concrete.stmt.Signatured;
import org.aya.concrete.visitor.ExprRefSubst;
import org.aya.core.Matching;
import org.aya.core.def.CtorDef;
import org.aya.core.def.DataDef;
import org.aya.core.def.Def;
import org.aya.core.def.FieldDef;
import org.aya.core.def.FnDef;
import org.aya.core.def.PrimDef;
import org.aya.core.def.StructDef;
import org.aya.core.pat.Pat;
import org.aya.core.sort.LevelSubst;
import org.aya.core.sort.Sort;
import org.aya.core.term.CallTerm;
import org.aya.core.term.FormTerm;
import org.aya.core.term.Term;
import org.aya.core.visitor.Substituter;
import org.aya.generic.GenericBuilder;
import org.aya.generic.Level;
import org.aya.generic.ParamLike;
import org.aya.tyck.ExprTycker;
import org.aya.tyck.LocalCtx;
import org.aya.tyck.pat.Conquer;
import org.aya.tyck.pat.PatClassifier;
import org.aya.tyck.pat.PatTycker;
import org.aya.tyck.trace.Trace;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/**
 * Type checks concrete declarations ({@link Decl}) and elaborates them into core definitions ({@link Def}).
 *
 * @param reporter     reporter handed to every {@link ExprTycker} created by {@link #newTycker()}
 * @param traceBuilder optional trace builder; when {@code null}, all tracing calls are no-ops
 */
public record StmtTycker(@NotNull Reporter reporter, @Nullable Trace.Builder traceBuilder) {
    /** Creates a fresh expression checker sharing this checker's reporter and trace builder. */
    @NotNull
    public ExprTycker newTycker() {
        return new ExprTycker(reporter, traceBuilder);
    }

    /** Runs {@code consumer} against the trace builder, but only when tracing is enabled. */
    private void tracing(@NotNull Consumer<Trace.@NotNull Builder> consumer) {
        if (traceBuilder != null) consumer.accept(traceBuilder);
    }

    /**
     * Applies {@code f} to {@code signatured} inside a {@link Trace.DeclT} trace frame, with the
     * checker's local context temporarily replaced by a context derived from the current one.
     *
     * <p>The trace frame is closed and the parent local context restored in a {@code finally}
     * block, so an exception thrown by {@code f} cannot leak a derived scope or an unbalanced
     * trace frame into whatever is checked next.
     *
     * @param signatured the declaration (or constructor/field) being checked
     * @param p          the expression checker whose {@code localCtx} is scoped for the call
     * @param f          the checking function to run
     * @return the definition produced by {@code f}
     */
    private <S extends Signatured, D extends Def> D traced(@NotNull S signatured, ExprTycker p, @NotNull BiFunction<S, ExprTycker, D> f) {
        tracing(builder -> builder.shift(new Trace.DeclT(signatured.ref(), signatured.sourcePos)));
        LocalCtx parent = p.localCtx;
        p.localCtx = parent.derive();
        try {
            return f.apply(signatured, p);
        } finally {
            // Same order as before: close the trace frame first, then restore the scope.
            tracing(GenericBuilder::reduce);
            p.localCtx = parent;
        }
    }

    /** Type checks {@code decl} under its own trace frame and derived local context. */
    public Def tyck(@NotNull Decl decl, @NotNull ExprTycker tycker) {
        return traced(decl, tycker, this::doTyck);
    }

    /**
     * Elaborates a declaration body into a core definition.
     *
     * <p>If the header has not been checked yet ({@code signature == null}), it is checked first;
     * otherwise the already-known signature parameters are re-registered in the current local
     * context so the body can refer to them.
     */
    private Def doTyck(@NotNull Decl predecl, @NotNull ExprTycker tycker) {
        if (predecl.signature == null) {
            tyckHeader(predecl, tycker);
        } else {
            predecl.signature.param().forEach(param -> tycker.localCtx.put(param.ref(), param.type()));
        }
        Def.Signature signature = predecl.signature;
        return switch (predecl) {
            case Decl.FnDecl decl -> {
                assert signature != null;
                BiFunction<Term, Either<Term, ImmutableSeq<Matching>>, FnDef> factory = FnDef.factory((resultTy, body) ->
                    new FnDef(decl.ref, signature.param(), signature.sortParam(), resultTy, body));
                yield decl.body.fold(
                    body -> {
                        // Expression body: check against the declared result type.
                        ExprTycker.Result result = tycker.zonk(body, tycker.inherit(body, signature.result()));
                        Term resultTy = signature.result().zonk(tycker, decl.result.sourcePos());
                        return factory.apply(resultTy, Either.left(result.wellTyped()));
                    },
                    clauses -> {
                        // Pattern-matching body: elaborate clauses, then check confluence and coverage.
                        PatTycker patTycker = new PatTycker(tycker);
                        Tuple2<Term, ImmutableSeq<Pat.PrototypeClause>> result = patTycker.elabClauses(clauses, signature);
                        ImmutableSeq<Matching> matchings = result._2.flatMap(Pat.PrototypeClause::deprototypify);
                        FnDef def = factory.apply(result._1, Either.right(matchings));
                        if (patTycker.noError()) {
                            ensureConfluent(tycker, signature, result._2, matchings, decl.sourcePos, true);
                        }
                        return def;
                    });
            }
            case Decl.DataDecl decl -> {
                assert signature != null;
                ImmutableSeq<CtorDef> body = decl.body.map(ctor -> traced(ctor, tycker, this::visitCtor));
                yield new DataDef(decl.ref, signature.param(), signature.sortParam(), signature.result(), body);
            }
            // Primitives: the core definition is already stored in the reference; return it as-is.
            case Decl.PrimDecl decl -> decl.ref.core;
            case Decl.StructDecl decl -> {
                assert signature != null;
                Term result = signature.result();
                ImmutableSeq<FieldDef> fields = decl.fields.map(field ->
                    traced(field, tycker, (f, tyck) -> visitField(f, tyck, result)));
                yield new StructDef(decl.ref, signature.param(), signature.sortParam(), result, fields);
            }
        };
    }

    /**
     * Checks the header (telescope and result type) of {@code decl} and stores the resulting
     * {@link Def.Signature} in {@code decl.signature}.
     *
     * @throws ExprTycker.TyckerException for a primitive with a non-empty telescope but no result type
     */
    public void tyckHeader(@NotNull Decl decl, @NotNull ExprTycker tycker) {
        switch (decl) {
            case Decl.FnDecl fn -> {
                tracing(builder -> builder.shift(new Trace.LabelT(fn.sourcePos, "telescope")));
                ImmutableSeq<Term.Param> resultTele;
                Term resultRes;
                try {
                    resultTele = checkTele(tycker, fn.telescope, FormTerm.freshUniv(fn.sourcePos));
                    resultRes = tycker.synthesize(fn.result).wellTyped();
                } finally {
                    // Keep the trace balanced even if checking the header throws.
                    tracing(GenericBuilder::reduce);
                }
                fn.signature = new Def.Signature(tycker.extractLevels(), resultTele, resultRes);
            }
            case Decl.DataDecl data -> {
                SourcePos pos = data.sourcePos;
                ImmutableSeq<Term.Param> tele = checkTele(tycker, data.telescope, FormTerm.freshUniv(pos));
                // A hole as the result type defaults to FormTerm.Univ.ZERO instead of being inferred.
                Term result = data.result instanceof Expr.HoleExpr
                    ? FormTerm.Univ.ZERO
                    : tycker.zonk(data.result, tycker.inherit(data.result, FormTerm.freshUniv(pos))).wellTyped();
                data.signature = new Def.Signature(tycker.extractLevels(), tele, result);
            }
            case Decl.StructDecl struct -> {
                SourcePos pos = struct.sourcePos;
                ImmutableSeq<Term.Param> tele = checkTele(tycker, struct.telescope, FormTerm.freshUniv(pos));
                Term result = tycker.zonk(struct.result, tycker.inherit(struct.result, FormTerm.freshUniv(pos))).wellTyped();
                struct.signature = new Def.Signature(tycker.extractLevels(), tele, result);
            }
            case Decl.PrimDecl prim -> {
                assert tycker.localCtx.isEmpty();
                PrimDef core = prim.ref.core;
                ImmutableSeq<Term.Param> tele = checkTele(tycker, prim.telescope, FormTerm.freshUniv(prim.sourcePos));
                if (tele.isNotEmpty()) {
                    if (prim.result == null) {
                        throw new ExprTycker.TyckerException();
                    }
                    Term result = tycker.synthesize(prim.result).wellTyped();
                    // Rename the core primitive's level variables to the levels just extracted from the
                    // user's header, then unify the user-written Pi type against the core one.
                    LevelSubst.Simple levelSubst = new LevelSubst.Simple(MutableMap.create());
                    ImmutableSeq<Sort.LvlVar> levels = tycker.extractLevels();
                    for (Tuple2<Sort.LvlVar, Sort.LvlVar> lvl : core.levels.zip(levels)) {
                        levelSubst.solution().put(lvl._1, new Sort(new Level.Reference<>(lvl._2)));
                    }
                    Term target = FormTerm.Pi.make(core.telescope(), core.result()).subst(Substituter.TermSubst.EMPTY, levelSubst);
                    tycker.unifyTyReported(FormTerm.Pi.make(tele, result), target, prim.result);
                    prim.signature = new Def.Signature(levels, tele, result);
                } else if (prim.result != null) {
                    Term result = tycker.synthesize(prim.result).wellTyped();
                    tycker.unifyTyReported(result, core.result(), prim.result);
                    // NOTE(review): unlike the two sibling branches, prim.signature is NOT assigned here,
                    // so a later doTyck would see signature == null and re-run tyckHeader. Confirm intended.
                } else {
                    prim.signature = new Def.Signature(ImmutableSeq.empty(), core.telescope(), core.result());
                }
                tycker.solveMetas();
            }
        }
    }

    /**
     * Elaborates a data constructor into a {@link CtorDef}.
     *
     * <p>The constructor's patterns are checked against the data type's own signature. The
     * constructor telescope is then checked after rewriting references bound by those patterns.
     * Finally, the data call is specialised by substituting the pattern terms for the data
     * parameters.
     */
    @NotNull
    private CtorDef visitCtor(@NotNull Decl.DataCtor ctor, ExprTycker tycker) {
        DefVar<DataDef, Decl.DataDecl> dataRef = ctor.dataRef;
        Def.Signature dataSig = dataRef.concrete.signature;
        assert dataSig != null;
        ImmutableSeq<Arg<Term>> dataArgs = dataSig.param().map(Term.Param::toArg);
        ImmutableSeq<Sort.LvlVar> sortParam = dataSig.sortParam();
        ImmutableSeq<Sort> sortArgs = sortParam.view().map(Level.Reference::new).map(Sort::new).toImmutableSeq();
        CallTerm.Data dataCall = new CallTerm.Data(dataRef, sortArgs, dataArgs);
        Def.Signature dataSelfSig = new Def.Signature(sortParam, dataSig.param(), dataCall);
        PatTycker patTycker = new PatTycker(tycker);
        ImmutableSeq<Pat> pat = ctor.patterns.isNotEmpty()
            ? patTycker.visitPatterns(dataSelfSig, ctor.patterns.view())._1
            : ImmutableSeq.empty();
        ImmutableSeq<Term.Param> tele = checkTele(tycker,
            ctor.telescope.map(param -> param.mapExpr(expr -> expr.accept(patTycker.refSubst, Unit.unit()))),
            dataSig.result());
        // The constructor's signature is built from the data call BEFORE it is specialised below.
        Def.Signature signature = new Def.Signature(sortParam, tele, dataCall);
        ctor.signature = signature;
        ExprRefSubst patSubst = patTycker.refSubst.clone();
        SeqView<Term.Param> dataParamView = dataSig.param().view();
        if (pat.isNotEmpty()) {
            ImmutableMap<Var, Term> subst = dataParamView.map(Term.Param::ref)
                .zip(pat.view().map(Pat::toTerm))
                .toImmutableMap();
            dataCall = (CallTerm.Data) dataCall.subst(subst);
        }
        ImmutableSeq<Pat.PrototypeClause> elabClauses = patTycker.elabClauses(patSubst, signature, ctor.clauses);
        ImmutableSeq<Matching> matchings = elabClauses.flatMap(Pat.PrototypeClause::deprototypify);
        // Without patterns, the data parameters become implicit constructor parameters;
        // otherwise the telescope is extracted from the patterns.
        ImmutableSeq<Term.Param> implicits = pat.isEmpty()
            ? dataParamView.map(Term.Param::implicitify).toImmutableSeq()
            : Pat.extractTele(pat);
        CtorDef elaborated = new CtorDef(dataRef, ctor.ref, pat, implicits, tele, matchings, dataCall, ctor.coerce);
        if (patTycker.noError()) {
            ensureConfluent(tycker, signature, elabClauses, matchings, ctor.sourcePos, false);
        }
        return elaborated;
    }

    /**
     * Checks elaborated clauses for confluence.
     *
     * <p>This is a no-op when {@code matchings} is empty. Otherwise it classifies the clauses
     * (forwarding {@code coverage} to {@link PatClassifier#classify}; function declarations pass
     * {@code true}, constructors and fields pass {@code false}), checks confluence, runs
     * {@link Conquer#against}, and solves remaining metas. All of this happens inside a
     * "confluence check" trace frame, which is closed even on failure.
     */
    private void ensureConfluent(ExprTycker tycker, Def.Signature signature, ImmutableSeq<Pat.PrototypeClause> elabClauses, ImmutableSeq<@NotNull Matching> matchings, @NotNull SourcePos pos, boolean coverage) {
        if (matchings.isEmpty()) {
            return;
        }
        tracing(builder -> builder.shift(new Trace.LabelT(pos, "confluence check")));
        try {
            ImmutableSeq<PatClassifier.PatClass> classification =
                PatClassifier.classify(elabClauses, signature.param(), tycker.state, tycker.reporter, pos, coverage);
            PatClassifier.confluence(elabClauses, tycker, pos, signature.result(), classification);
            Conquer.against(matchings, tycker, pos, signature);
            tycker.solveMetas();
        } finally {
            tracing(GenericBuilder::reduce);
        }
    }

    /**
     * Elaborates a structure field into a {@link FieldDef}.
     *
     * <p>The field's telescope and result type are checked against {@code structResult}, its
     * {@code field.signature} is stored, and its clauses and optional default body are elaborated.
     */
    @NotNull
    private FieldDef visitField(@NotNull Decl.StructField field, ExprTycker tycker, @NotNull Term structResult) {
        ImmutableSeq<Term.Param> tele = checkTele(tycker, field.telescope, structResult);
        DefVar<StructDef, Decl.StructDecl> structRef = field.structRef;
        Term result = tycker.zonk(field.result, tycker.inherit(field.result, structResult)).wellTyped();
        Def.Signature structSig = structRef.concrete.signature;
        assert structSig != null;
        field.signature = new Def.Signature(structSig.sortParam(), tele, result);
        PatTycker patTycker = new PatTycker(tycker);
        ImmutableSeq<Pat.PrototypeClause> elabClauses = patTycker.elabClauses(null, field.signature, field.clauses);
        ImmutableSeq<Matching> matchings = elabClauses.flatMap(Pat.PrototypeClause::deprototypify);
        Option<Term> body = field.body.map(e -> tycker.inherit(e, result).wellTyped());
        FieldDef elaborated = new FieldDef(structRef, field.ref, structSig.param(), tele, result, matchings, body, field.coerce);
        if (patTycker.noError()) {
            ensureConfluent(tycker, field.signature, elabClauses, matchings, field.sourcePos, false);
        }
        return elaborated;
    }

    /**
     * Checks a concrete telescope, with each parameter type checked against {@code univ}.
     *
     * <p>Each parameter is put into the local context as soon as its type is checked, so later
     * parameters can refer to earlier ones. After all parameters are checked, metas are solved,
     * each type is zonked, and the zonked type is re-registered in the local context.
     *
     * @return the checked telescope with zonked parameter types
     */
    @NotNull
    private ImmutableSeq<Term.Param> checkTele(@NotNull ExprTycker exprTycker, @NotNull ImmutableSeq<Expr.Param> tele, @NotNull Term univ) {
        ImmutableSeq<Tuple2<Term.Param, SourcePos>> okTele = tele.map(param -> {
            assert param.type() != null;
            Term paramTyped = exprTycker.inherit(param.type(), univ).wellTyped();
            exprTycker.localCtx.put(param.ref(), paramTyped);
            return Tuple.of(new Term.Param(param, paramTyped), param.sourcePos());
        });
        exprTycker.solveMetas();
        return okTele.map(checked -> {
            Term.Param p = checked._1;
            Term zonked = p.type().zonk(exprTycker, checked._2);
            exprTycker.localCtx.put(p.ref(), zonked);
            return new Term.Param(p, zonked);
        });
    }
}

