/*
 * Decompiled with CFR 0.152.
 */
package org.aya.tyck;

import kala.collection.immutable.ImmutableSeq;
import kala.tuple.Unit;
import org.aya.api.ref.DefVar;
import org.aya.api.util.NormalizeMode;
import org.aya.concrete.stmt.Decl;
import org.aya.core.def.Def;
import org.aya.core.def.FieldDef;
import org.aya.core.def.StructDef;
import org.aya.core.sort.Sort;
import org.aya.core.term.CallTerm;
import org.aya.core.term.ElimTerm;
import org.aya.core.term.ErrorTerm;
import org.aya.core.term.FormTerm;
import org.aya.core.term.IntroTerm;
import org.aya.core.term.RefTerm;
import org.aya.core.term.Term;
import org.aya.core.visitor.Substituter;
import org.aya.core.visitor.Unfolder;
import org.aya.generic.Constants;
import org.aya.tyck.TyckState;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/**
 * Synthesizes the type of an already-elaborated core {@link Term}.
 * <p>
 * Each {@code visitX} method computes a type directly from the structure of the term. It
 * does not report errors. When a sub-term's type does not have the expected shape after
 * weak-head normalization (for example, the head of an application is not a Pi type), the
 * method falls back to {@code ErrorTerm.typeOf(term)}.
 *
 * @param state typechecking state handed to {@code normalize} when intermediate types are
 *              reduced to weak-head normal form; may be {@code null}
 */
public record LittleTyper(@Nullable TyckState state) implements Term.Visitor<Unit, Term>
{
    /** The type of a reference is the type recorded on the reference itself. */
    @Override
    public Term visitRef(@NotNull RefTerm term, Unit unit) {
        return term.type();
    }

    /** A lambda has a Pi type over the same parameter, whose codomain is the type of the body. */
    @Override
    public Term visitLam(@NotNull IntroTerm.Lambda term, Unit unit) {
        return new FormTerm.Pi(term.param(), term.body().accept(this, unit));
    }

    /**
     * A Pi type lives in the universe whose sort is the max of its domain's and codomain's sorts.
     * Falls back to {@code ErrorTerm.typeOf(term)} if either side's type does not normalize to a
     * universe.
     */
    @Override
    public Term visitPi(@NotNull FormTerm.Pi term, Unit unit) {
        Term paramTyRaw = term.param().type().accept(this, Unit.unit()).normalize(this.state, NormalizeMode.WHNF);
        Term retTyRaw = term.body().accept(this, Unit.unit()).normalize(this.state, NormalizeMode.WHNF);
        if (paramTyRaw instanceof FormTerm.Univ paramTy && retTyRaw instanceof FormTerm.Univ retTy) {
            return new FormTerm.Univ(Sort.max(paramTy.sort(), retTy.sort()));
        }
        return ErrorTerm.typeOf(term);
    }

    /** The type of an error term is delegated to {@code ErrorTerm.typeOf}. */
    @Override
    public Term visitError(@NotNull ErrorTerm term, Unit unit) {
        return ErrorTerm.typeOf(term);
    }

    /**
     * A Sigma type lives in the universe whose sort is the max of the sorts of all its
     * components. Falls back to {@code ErrorTerm.typeOf(term)} if any component's type does not
     * normalize to a universe.
     */
    @Override
    public Term visitSigma(@NotNull FormTerm.Sigma term, Unit unit) {
        // Keep only the component types that are universes. If any component was dropped,
        // the sizes differ and the Sigma cannot be given a universe.
        ImmutableSeq<FormTerm.Univ> universes = term.params().view()
            .map(param -> param.type().accept(this, Unit.unit()).normalize(this.state, NormalizeMode.WHNF))
            .filterIsInstance(FormTerm.Univ.class)
            .toImmutableSeq();
        if (!universes.sizeEquals(term.params().size())) {
            return ErrorTerm.typeOf(term);
        }
        // NOTE(review): `reduce` has no seed value, so this assumes a Sigma always has at
        // least one parameter — confirm.
        Sort sort = universes.view().map(FormTerm.Univ::sort).reduce(Sort::max);
        return new FormTerm.Univ(sort);
    }

    /** A universe's type is the universe one level higher ({@code sort().lift(1)}). */
    @Override
    public Term visitUniv(@NotNull FormTerm.Univ term, Unit unit) {
        return new FormTerm.Univ(term.sort().lift(1));
    }

    /**
     * The type of {@code f a} is the codomain of {@code f}'s Pi type with {@code a} substituted
     * for the bound parameter. Falls back to {@code ErrorTerm.typeOf(term)} if {@code f}'s type
     * does not normalize to a Pi.
     */
    @Override
    public Term visitApp(@NotNull ElimTerm.App term, Unit unit) {
        Term piRaw = term.of().accept(this, unit).normalize(this.state, NormalizeMode.WHNF);
        if (piRaw instanceof FormTerm.Pi pi) {
            return pi.substBody(term.arg().term());
        }
        return ErrorTerm.typeOf(term);
    }

    @Override
    public Term visitFnCall(@NotNull CallTerm.Fn fnCall, Unit unit) {
        return this.defCall(fnCall.ref(), fnCall.sortArgs());
    }

    @Override
    public Term visitDataCall(@NotNull CallTerm.Data dataCall, Unit unit) {
        return this.defCall(dataCall.ref(), dataCall.sortArgs());
    }

    /**
     * Computes the type of a constructor call.
     * <p>
     * NOTE(review): this goes through {@link #defCall} with the constructor's data reference.
     * That is the same computation {@link #visitDataCall} performs for the data type itself,
     * i.e. the data definition's own result rather than the data type applied to its
     * arguments. Confirm this is intended.
     */
    @Override
    public Term visitConCall(@NotNull CallTerm.Con conCall, Unit unit) {
        return this.defCall(conCall.head().dataRef(), conCall.sortArgs());
    }

    @Override
    public Term visitStructCall(@NotNull CallTerm.Struct structCall, Unit unit) {
        return this.defCall(structCall.ref(), structCall.sortArgs());
    }

    /**
     * Instantiates the declared result type of the definition {@code ref} with the given
     * universe-level arguments.
     * <p>
     * NOTE(review): only the level variables are substituted; the term substitution is
     * {@code TermSubst.EMPTY}. A result type that mentions the definition's parameters is
     * therefore returned with those parameters unsubstituted — verify callers tolerate this.
     *
     * @param ref      the definition being called
     * @param sortArgs sort arguments, matched positionally against {@code Def.defLevels(ref)}
     * @return the result type with level variables replaced
     */
    @NotNull
    private Term defCall(DefVar<? extends Def, ? extends Decl> ref, ImmutableSeq<@NotNull Sort> sortArgs) {
        ImmutableSeq<Sort.LvlVar> levels = Def.defLevels(ref);
        return Def.defResult(ref).subst(Substituter.TermSubst.EMPTY, Unfolder.buildSubst(levels, sortArgs));
    }

    @Override
    public Term visitPrimCall(@NotNull CallTerm.Prim prim, Unit unit) {
        return this.defCall(prim.ref(), prim.sortArgs());
    }

    /** A tuple has the non-dependent Sigma type made of its items' types, using anonymous names. */
    @Override
    public Term visitTup(@NotNull IntroTerm.Tuple term, Unit unit) {
        return new FormTerm.Sigma(term.items().map(item -> new Term.Param(Constants.anonymous(), item.accept(this, Unit.unit()), true)));
    }

    /** A {@code new} expression has the struct type it instantiates. */
    @Override
    public Term visitNew(@NotNull IntroTerm.New newTerm, Unit unit) {
        return newTerm.struct();
    }

    /**
     * The type of the projection {@code p.ix} (where {@code ix} is 1-based) is the type of the
     * corresponding Sigma component. Earlier components in that type are replaced by
     * projections of {@code p}. Falls back to {@code ErrorTerm.typeOf(term)} if {@code p}'s type
     * does not normalize to a Sigma.
     */
    @Override
    public Term visitProj(@NotNull ElimTerm.Proj term, Unit unit) {
        Term sigmaRaw = term.of().accept(this, unit).normalize(this.state, NormalizeMode.WHNF);
        if (!(sigmaRaw instanceof FormTerm.Sigma sigma)) {
            return ErrorTerm.typeOf(term);
        }
        // Projection indices are 1-based; the telescope is 0-based.
        int index = term.ix() - 1;
        ImmutableSeq<Term.Param> telescope = sigma.params();
        return telescope.get(index).type().subst(ElimTerm.Proj.projSubst(term.of(), index, telescope));
    }

    /**
     * The type of a field access is the field's declared result type. It is substituted with
     * the field's own arguments and the arguments of the enclosing struct. Falls back to
     * {@code ErrorTerm.typeOf(term)} if the accessed term's type does not normalize to a
     * struct call.
     */
    @Override
    public Term visitAccess(@NotNull CallTerm.Access term, Unit unit) {
        Term callRaw = term.of().accept(this, unit).normalize(this.state, NormalizeMode.WHNF);
        if (!(callRaw instanceof CallTerm.Struct call)) {
            return ErrorTerm.typeOf(term);
        }
        FieldDef core = term.ref().core;
        StructDef struct = call.ref().core;
        Substituter.TermSubst subst = Unfolder.buildSubst(core.telescope(), term.fieldArgs())
            .add(Unfolder.buildSubst(struct.telescope(), term.structArgs()));
        return core.result().subst(subst);
    }

    /** A hole's type is the result type stored on its meta variable. */
    @Override
    public Term visitHole(@NotNull CallTerm.Hole term, Unit unit) {
        return term.ref().result;
    }

    /** A field reference's type is the declared type of the referenced field. */
    @Override
    public Term visitFieldRef(@NotNull RefTerm.Field term, Unit unit) {
        return Def.defType(term.ref());
    }
}

