/*
 * Decompiled with CFR 0.152.
 */
package org.aya.core.serde;

import java.util.Objects;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.mutable.MutableMap;
import kala.control.Either;
import kala.control.Option;
import kala.tuple.Unit;
import org.aya.api.ref.DefVar;
import org.aya.api.ref.LocalVar;
import org.aya.api.util.Arg;
import org.aya.core.Matching;
import org.aya.core.def.CtorDef;
import org.aya.core.def.DataDef;
import org.aya.core.def.Def;
import org.aya.core.def.FieldDef;
import org.aya.core.def.FnDef;
import org.aya.core.def.PrimDef;
import org.aya.core.def.StructDef;
import org.aya.core.pat.Pat;
import org.aya.core.serde.SerDef;
import org.aya.core.serde.SerLevel;
import org.aya.core.serde.SerPat;
import org.aya.core.serde.SerTerm;
import org.aya.core.sort.Sort;
import org.aya.core.term.CallTerm;
import org.aya.core.term.ElimTerm;
import org.aya.core.term.ErrorTerm;
import org.aya.core.term.FormTerm;
import org.aya.core.term.IntroTerm;
import org.aya.core.term.RefTerm;
import org.aya.core.term.Term;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

public record Serializer(@NotNull State state) implements Term.Visitor<Unit, SerTerm>,
Def.Visitor<Unit, SerDef>,
Pat.Visitor<Unit, SerPat>
{
    @NotNull
    private SerTerm serialize(@NotNull Term term) {
        return term.accept(this, Unit.unit());
    }

    @NotNull
    private SerPat serialize(@NotNull Pat pat) {
        return pat.accept(this, Unit.unit());
    }

    @NotNull
    private SerPat.Matchy serialize(@NotNull Matching matchy) {
        return new SerPat.Matchy(this.serializePats(matchy.patterns()), this.serialize(matchy.body()));
    }

    private SerTerm.SerArg serialize(@NotNull @NotNull Arg<@NotNull Term> termArg) {
        return new SerTerm.SerArg(this.serialize((Term)termArg.term()), termArg.explicit());
    }

    @Contract(value="_ -> new")
    private SerTerm.SerParam serialize(@NotNull Term.Param param) {
        return new SerTerm.SerParam(param.explicit(), param.pattern(), this.state.local(param.ref()), this.serialize(param.type()));
    }

    @NotNull
    private ImmutableSeq<SerTerm.SerParam> serializeParams(ImmutableSeq<@NotNull Term.Param> params) {
        return params.map(this::serialize);
    }

    @Override
    public SerTerm visitError(@NotNull ErrorTerm term, Unit unit) {
        throw new AssertionError((Object)"Shall not have error term serialized.");
    }

    @Override
    public SerTerm visitHole( @NotNull CallTerm.Hole term, Unit unit) {
        throw new AssertionError((Object)"Shall not have holes serialized.");
    }

    @Override
    public SerTerm visitFieldRef(@NotNull RefTerm.Field term, Unit unit) {
        return new SerTerm.FieldRef(this.state.def(term.ref()));
    }

    @Override
    public SerTerm visitRef(@NotNull RefTerm term, Unit unit) {
        return new SerTerm.Ref(this.state.local(term.var()), this.serialize(term.type()));
    }

    @Override
    public SerTerm visitLam( @NotNull IntroTerm.Lambda term, Unit unit) {
        return new SerTerm.Lam(this.serialize(term.param()), this.serialize(term.body()));
    }

    @Override
    public SerTerm visitPi( @NotNull FormTerm.Pi term, Unit unit) {
        return new SerTerm.Pi(this.serialize(term.param()), this.serialize(term.body()));
    }

    @Override
    public SerTerm visitSigma( @NotNull FormTerm.Sigma term, Unit unit) {
        return new SerTerm.Sigma(this.serializeParams(term.params()));
    }

    @Override
    public SerTerm visitUniv( @NotNull FormTerm.Univ term, Unit unit) {
        return new SerTerm.Univ(this.serialize(term.sort()));
    }

    @NotNull
    private ImmutableSeq<SerTerm.SerArg> serializeArgs(@NotNull ImmutableSeq<Arg<Term>> args) {
        return args.map(this::serialize);
    }

    private @NotNull SerLevel.Max serialize(@NotNull Sort level) {
        return SerLevel.ser(level, this.state.levelCache());
    }

    @NotNull
    private ImmutableSeq<SerLevel.Max> serializeLevels(@NotNull ImmutableSeq<Sort> sortArgs) {
        return sortArgs.map(this::serialize);
    }

    @NotNull
    private ImmutableSeq<SerPat> serializePats(@NotNull ImmutableSeq<Pat> pats) {
        return pats.map(this::serialize);
    }

    @Override
    public SerTerm visitApp( @NotNull ElimTerm.App term, Unit unit) {
        return new SerTerm.App(this.serialize(term.of()), this.serialize(term.arg()));
    }

    @NotNull
    private SerTerm.CallData serializeCall(@NotNull @NotNull ImmutableSeq<@NotNull Sort> sortArgs, @NotNull @NotNull ImmutableSeq<Arg<@NotNull Term>> args) {
        return new SerTerm.CallData(this.serializeLevels(sortArgs), this.serializeArgs(args));
    }

    @Override
    public SerTerm visitFnCall(@NotNull CallTerm.Fn fnCall, Unit unit) {
        return new SerTerm.FnCall(this.state.def(fnCall.ref()), this.serializeCall(fnCall.sortArgs(), fnCall.args()));
    }

    @Override
    public SerTerm.DataCall visitDataCall(@NotNull CallTerm.Data dataCall, Unit unit) {
        return new SerTerm.DataCall(this.state.def(dataCall.ref()), this.serializeCall(dataCall.sortArgs(), dataCall.args()));
    }

    @Override
    public SerTerm visitConCall(@NotNull CallTerm.Con conCall, Unit unit) {
        return new SerTerm.ConCall(this.state.def(conCall.head().dataRef()), this.state.def(conCall.head().ref()), this.serializeCall(conCall.head().sortArgs(), conCall.head().dataArgs()), this.serializeArgs(conCall.args()));
    }

    @Override
    public SerTerm visitStructCall(@NotNull CallTerm.Struct structCall, Unit unit) {
        return new SerTerm.StructCall(this.state.def(structCall.ref()), this.serializeCall(structCall.sortArgs(), structCall.args()));
    }

    @Override
    public SerTerm visitPrimCall( @NotNull CallTerm.Prim prim, Unit unit) {
        return new SerTerm.PrimCall(this.state.def(prim.ref()), this.serializeCall(prim.sortArgs(), prim.args()));
    }

    @Override
    public SerTerm visitTup( @NotNull IntroTerm.Tuple term, Unit unit) {
        return new SerTerm.Tup((ImmutableSeq<SerTerm>)term.items().map(this::serialize));
    }

    @Override
    public SerTerm visitNew( @NotNull IntroTerm.New newTerm, Unit unit) {
        return new SerTerm.New(new SerTerm.StructCall(this.state.def(newTerm.struct().ref()), this.serializeCall(newTerm.struct().sortArgs(), newTerm.struct().args())));
    }

    @Override
    public SerTerm visitProj( @NotNull ElimTerm.Proj term, Unit unit) {
        return new SerTerm.Proj(this.serialize(term.of()), term.ix());
    }

    @Override
    public SerTerm visitAccess( @NotNull CallTerm.Access term, Unit unit) {
        return new SerTerm.Access(this.serialize(term.of()), this.state.def(term.ref()), this.serializeLevels(term.sortArgs()), this.serializeArgs(term.structArgs()), this.serializeArgs(term.fieldArgs()));
    }

    @Override
    public SerPat visitBind(@NotNull Pat.Bind bind, Unit unit) {
        return new SerPat.Bind(bind.explicit(), this.state.local(bind.as()), this.serialize(bind.type()));
    }

    @Override
    public SerPat visitTuple(@NotNull Pat.Tuple tuple, Unit unit) {
        return new SerPat.Tuple(tuple.explicit(), this.serializePats(tuple.pats()), this.state.localMaybe(tuple.as()), this.serialize(tuple.type()));
    }

    @Override
    public SerPat visitCtor(@NotNull Pat.Ctor ctor, Unit unit) {
        return new SerPat.Ctor(ctor.explicit(), this.state.def(ctor.ref()), this.serializePats(ctor.params()), this.state.localMaybe(ctor.as()), this.visitDataCall(ctor.type(), unit));
    }

    @Override
    public SerPat visitAbsurd(@NotNull Pat.Absurd absurd, Unit unit) {
        return new SerPat.Absurd(absurd.explicit(), this.serialize(absurd.type()));
    }

    @Override
    public SerPat visitPrim(@NotNull Pat.Prim prim, Unit unit) {
        return new SerPat.Prim(prim.explicit(), this.state.def(prim.ref()), this.serialize(prim.type()));
    }

    @Override
    public SerDef visitFn(@NotNull FnDef def, Unit unit) {
        return new SerDef.Fn(this.state.def(def.ref), this.serializeParams((ImmutableSeq<Term.Param>)def.telescope), (ImmutableSeq<SerLevel.LvlVar>)def.levels.map(lvl -> SerLevel.ser(lvl, this.state.levelCache)), (Either<SerTerm, ImmutableSeq<SerPat.Matchy>>)def.body.map(this::serialize, matchings -> matchings.map(this::serialize)), this.serialize(def.result));
    }

    @Override
    public SerDef visitData(@NotNull DataDef def, Unit unit) {
        return new SerDef.Data(this.state.def(def.ref), this.serializeParams((ImmutableSeq<Term.Param>)def.telescope), (ImmutableSeq<SerLevel.LvlVar>)def.levels.map(lvl -> SerLevel.ser(lvl, this.state.levelCache)), this.serialize(def.result), (ImmutableSeq<SerDef.Ctor>)def.body.map(ctor -> this.visitCtor((CtorDef)ctor, Unit.unit())));
    }

    @Override
    public SerDef.Ctor visitCtor(@NotNull CtorDef def, Unit unit) {
        return new SerDef.Ctor(this.state.def(def.dataRef), this.state.def(def.ref), this.serializePats(def.pats), this.serializeParams((ImmutableSeq<Term.Param>)def.ownerTele), this.serializeParams((ImmutableSeq<Term.Param>)def.selfTele), (ImmutableSeq<SerPat.Matchy>)def.clauses.map(this::serialize), this.serialize(def.result), def.coerce);
    }

    @Override
    public SerDef visitStruct(@NotNull StructDef def, Unit unit) {
        return new SerDef.Struct(this.state.def(def.ref()), this.serializeParams((ImmutableSeq<Term.Param>)def.telescope), (ImmutableSeq<SerLevel.LvlVar>)def.levels.map(lvl -> SerLevel.ser(lvl, this.state.levelCache)), this.serialize(def.result), (ImmutableSeq<SerDef.Field>)def.fields.map(field -> this.visitField((FieldDef)field, Unit.unit())));
    }

    @Override
    public SerDef.Field visitField(@NotNull FieldDef def, Unit unit) {
        return new SerDef.Field(this.state.def(def.structRef), this.state.def(def.ref), this.serializeParams((ImmutableSeq<Term.Param>)def.ownerTele), this.serializeParams((ImmutableSeq<Term.Param>)def.selfTele), this.serialize(def.result), (ImmutableSeq<SerPat.Matchy>)def.clauses.map(this::serialize), (Option<SerTerm>)def.body.map(this::serialize), def.coerce);
    }

    @Override
    public SerDef visitPrim(@NotNull PrimDef def, Unit unit) {
        return new SerDef.Prim(this.serializeParams((ImmutableSeq<Term.Param>)def.telescope), (ImmutableSeq<SerLevel.LvlVar>)def.levels.map(lvl -> SerLevel.ser(lvl, this.state.levelCache)), this.serialize(def.result), Objects.requireNonNull(PrimDef.ID.find(def.ref.name())));
    }

    public record State(@NotNull MutableMap<Sort.LvlVar, Integer> levelCache, @NotNull MutableMap<LocalVar, Integer> localCache, @NotNull MutableMap<DefVar<?, ?>, Integer> defCache) {
        public State() {
            this((MutableMap<Sort.LvlVar, Integer>)MutableMap.create(), (MutableMap<LocalVar, Integer>)MutableMap.create(), MutableMap.create());
        }

        @NotNull
        public SerTerm.SimpVar local(@NotNull LocalVar var) {
            return new SerTerm.SimpVar((Integer)this.localCache.getOrPut((Object)var, () -> this.localCache.size()), var.name());
        }

        @NotNull
        public SerTerm.SimpVar localMaybe(@Nullable LocalVar var) {
            if (var == null) {
                return new SerTerm.SimpVar(-1, "");
            }
            return this.local(var);
        }

        @NotNull
        public SerDef.QName def(@NotNull DefVar<?, ?> var) {
            assert (var.module != null);
            return new SerDef.QName((ImmutableSeq<String>)var.module, var.name(), (Integer)this.defCache.getOrPut(var, () -> this.defCache.size()));
        }
    }
}

