/*
 * Decompiled with CFR 0.152.
 */
package org.aya.tyck.tycker;

import java.lang.runtime.SwitchBootstraps;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.UnaryOperator;
import kala.collection.Seq;
import kala.collection.SeqView;
import kala.collection.immutable.ImmutableArray;
import kala.collection.immutable.ImmutableSeq;
import kala.function.CheckedBiFunction;
import org.aya.generic.Modifier;
import org.aya.generic.stmt.Shaped;
import org.aya.syntax.compile.JitCon;
import org.aya.syntax.compile.JitData;
import org.aya.syntax.compile.JitDef;
import org.aya.syntax.compile.JitFn;
import org.aya.syntax.compile.JitPrim;
import org.aya.syntax.concrete.stmt.decl.ClassDecl;
import org.aya.syntax.concrete.stmt.decl.ClassMember;
import org.aya.syntax.concrete.stmt.decl.DataCon;
import org.aya.syntax.concrete.stmt.decl.DataDecl;
import org.aya.syntax.concrete.stmt.decl.Decl;
import org.aya.syntax.concrete.stmt.decl.FnDecl;
import org.aya.syntax.concrete.stmt.decl.PrimDecl;
import org.aya.syntax.core.Jdg;
import org.aya.syntax.core.def.AnyDef;
import org.aya.syntax.core.def.ClassDef;
import org.aya.syntax.core.def.ClassDefLike;
import org.aya.syntax.core.def.ConDef;
import org.aya.syntax.core.def.ConDefLike;
import org.aya.syntax.core.def.DataDef;
import org.aya.syntax.core.def.DataDefLike;
import org.aya.syntax.core.def.FnDef;
import org.aya.syntax.core.def.FnDefLike;
import org.aya.syntax.core.def.MemberDef;
import org.aya.syntax.core.def.MemberDefLike;
import org.aya.syntax.core.def.PrimDef;
import org.aya.syntax.core.def.PrimDefLike;
import org.aya.syntax.core.repr.AyaShape;
import org.aya.syntax.core.repr.ShapeRecognition;
import org.aya.syntax.core.term.Term;
import org.aya.syntax.core.term.call.ClassCall;
import org.aya.syntax.core.term.call.ConCall;
import org.aya.syntax.core.term.call.DataCall;
import org.aya.syntax.core.term.call.FnCall;
import org.aya.syntax.core.term.call.MemberCall;
import org.aya.syntax.core.term.call.PrimCall;
import org.aya.syntax.core.term.call.RuleReducer;
import org.aya.syntax.ref.DefVar;
import org.aya.syntax.ref.LocalVar;
import org.aya.syntax.telescope.AbstractTele;
import org.aya.tyck.TyckState;
import org.aya.tyck.tycker.AbstractTycker;
import org.aya.tyck.tycker.Stateful;
import org.aya.util.Panic;
import org.aya.util.position.SourcePos;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/**
 * Type-checks applications of definitions (functions, data types, primitives, constructors,
 * classes and class members). Each {@code check*Call} lifts the definition's signature by
 * {@link #lift}, lets {@link #makeArgs} elaborate the arguments against it, and builds the
 * resulting core term together with its type as a {@link Jdg}.
 *
 * @param state     the type-checking state (also the {@link Stateful} state)
 * @param tycker    the tycker this checker was created for
 * @param pos       source position of the application
 * @param argsCount number of supplied arguments; used as the member count for class calls
 * @param lift      lift level passed to {@code AbstractTele#lift} on every signature
 * @param makeArgs  elaborates arguments against a telescope and produces the final judgement
 * @param <Ex>      the checked exception {@code makeArgs} may throw
 */
public record AppTycker<Ex extends Exception>(
    @NotNull TyckState state, @NotNull AbstractTycker tycker, @NotNull SourcePos pos,
    int argsCount, int lift, @NotNull Factory<Ex> makeArgs
) implements Stateful
{
    /** Creates an application checker that shares the state of {@code tycker}. */
    public AppTycker(@NotNull AbstractTycker tycker, @NotNull SourcePos pos, int argsCount, int lift, @NotNull Factory<Ex> makeArgs) {
        this(tycker.state, tycker, pos, argsCount, lift, makeArgs);
    }

    /**
     * Checks an application of a compiled (JIT) definition.
     *
     * @param def the compiled definition being applied
     * @return the judgement for the application
     * @throws Ex    if argument elaboration fails
     * @throws Panic if {@code def} is not a function, data type, primitive or constructor
     */
    @NotNull
    public Jdg checkCompiledApplication(@NotNull JitDef def) throws Ex {
        return switch (def) {
            case JitFn fn -> {
                // A shape index of -1 in the metadata is treated as "no shape": no operator is attached.
                int shape = fn.metadata().shape();
                Shaped.Applicable<FnDefLike> operator =
                    shape != -1 ? AyaShape.ofFn(fn, AyaShape.values()[shape]) : null;
                yield checkFnCall(fn, operator);
            }
            case JitData data -> checkDataCall(data);
            case JitPrim prim -> checkPrimCall(prim);
            case JitCon con -> checkConCall(con);
            default -> throw new Panic(def.getClass().getCanonicalName());
        };
    }

    /**
     * Checks an application of a source-level definition, dispatching on the kind of its
     * concrete declaration. For functions, a shape recognized by {@code state.shapeFactory}
     * is attached as an operator.
     *
     * @param defVar the referenced definition; its {@code concrete} declaration must be non-null
     * @return the judgement for the application
     * @throws Ex    if argument elaboration fails
     * @throws Panic if the declaration kind is not supported
     */
    @SuppressWarnings("unchecked") // the concrete declaration's kind determines the DefVar's type arguments
    @NotNull
    public Jdg checkDefApplication(@NotNull DefVar<?, ?> defVar) throws Ex {
        return switch (defVar.concrete) {
            case FnDecl fnDecl -> {
                var fnDef = new FnDef.Delegate((DefVar<FnDef, FnDecl>) defVar);
                Shaped.Applicable<FnDefLike> op = state.shapeFactory.find(fnDef)
                    .map(recog -> AyaShape.ofFn(fnDef, recog.shape()))
                    .getOrNull();
                yield checkFnCall(fnDef, op);
            }
            case DataDecl dataDecl -> checkDataCall(new DataDef.Delegate((DefVar<DataDef, DataDecl>) defVar));
            case PrimDecl primDecl -> checkPrimCall(new PrimDef.Delegate((DefVar<PrimDef, PrimDecl>) defVar));
            case DataCon dataCon -> checkConCall(new ConDef.Delegate((DefVar<ConDef, DataCon>) defVar));
            case ClassDecl classDecl -> checkClassCall(new ClassDef.Delegate((DefVar<ClassDef, ClassDecl>) defVar));
            case ClassMember classMember -> checkProjCall(new MemberDef.Delegate((DefVar<MemberDef, ClassMember>) defVar));
            case Decl any -> throw new Panic(any.getClass().getCanonicalName());
        };
    }

    /**
     * Checks a constructor application. The first {@code ownerTeleSize()} arguments belong to the
     * owning data type and the rest to the constructor itself. The result is a
     * {@link RuleReducer.Con} if the data type has a recognized shape, and a {@link ConCall} otherwise.
     */
    @NotNull
    private Jdg checkConCall(@NotNull ConDefLike conVar) throws Ex {
        var dataVar = conVar.dataRef();
        var fullSignature = conVar.signature().lift(lift);
        return makeArgs.applyChecked(fullSignature, (args, term) -> {
            var realArgs = ImmutableArray.from(args);
            var ownerArgs = realArgs.take(conVar.ownerTeleSize());
            var conArgs = realArgs.drop(conVar.ownerTeleSize());
            var type = (DataCall) fullSignature.result(realArgs);
            var shape = state.shapeFactory.find(dataVar)
                .mapNotNull(recog -> AyaShape.ofCon(conVar, recog, type))
                .getOrNull();
            if (shape != null) {
                return new Jdg.Default(new RuleReducer.Con(shape, 0, ownerArgs, conArgs), type);
            }
            return new Jdg.Default(new ConCall(conVar, ownerArgs, 0, conArgs), type);
        });
    }

    /** Checks a primitive application; the call is passed through {@code state.primFactory.unfold}. */
    @NotNull
    private Jdg checkPrimCall(@NotNull PrimDefLike primVar) throws Ex {
        var signature = primVar.signature().lift(lift);
        return makeArgs.applyChecked(signature, (args, term) -> {
            var call = new PrimCall(primVar, 0, ImmutableArray.from(args));
            return new Jdg.Default(state.primFactory.unfold(call, state), signature.result(args));
        });
    }

    /** Checks a data type application, producing a {@link DataCall}. */
    @NotNull
    private Jdg checkDataCall(@NotNull DataDefLike data) throws Ex {
        var signature = data.signature().lift(lift);
        return makeArgs.applyChecked(signature, (args, term) ->
            new Jdg.Default(new DataCall(data, 0, ImmutableArray.from(args)), signature.result(args)));
    }

    /**
     * Checks a function application.
     *
     * @param fnDef    the function being applied
     * @param operator if non-null, the call becomes a {@link RuleReducer.Fn} over this operator
     * @return a rule reducer if {@code operator} is given; otherwise, for {@code inline} functions,
     *         the body instantiated with the arguments; otherwise a plain {@link FnCall}
     */
    @NotNull
    private Jdg checkFnCall(@NotNull FnDefLike fnDef, @Nullable Shaped.Applicable<FnDefLike> operator) throws Ex {
        var signature = fnDef.signature().lift(lift);
        return makeArgs.applyChecked(signature, (args, term) -> {
            var argsSeq = ImmutableArray.from(args);
            var result = signature.result(args);
            if (operator != null) {
                return new Jdg.Default(new RuleReducer.Fn(operator, 0, argsSeq), result);
            }
            // Unfolding an inline function yields an arbitrary Term, not necessarily an FnCall.
            Term call;
            if (fnDef.is(Modifier.Inline)) {
                call = switch (fnDef) {
                    case JitFn jit -> jit.invoke(UnaryOperator.identity(), argsSeq);
                    case FnDef.Delegate def -> def.core().body().getLeftValue().instTele(argsSeq.view());
                };
            } else {
                call = new FnCall(fnDef, 0, argsSeq);
            }
            return new Jdg.Default(call, result);
        });
    }

    /**
     * Checks a class application over its first {@link #argsCount} members. Each checked argument
     * is bound over a fresh {@code self} variable, which is kept on {@code state.classThis} while
     * the arguments are elaborated.
     */
    @NotNull
    private Jdg checkClassCall(@NotNull ClassDefLike clazz) throws Ex {
        var self = LocalVar.generate("self");
        var appliedParams = ofClassMembers(clazz, argsCount).lift(lift);
        state.classThis.push(self);
        try {
            return makeArgs.applyChecked(appliedParams, (args, term) -> new Jdg.Default(
                new ClassCall(clazz, 0, ImmutableArray.from(args).map(x -> x.bind(self))),
                appliedParams.result(args)));
        } finally {
            // Pop even if elaboration throws, so `self` never leaks into later checks.
            state.classThis.pop();
        }
    }

    /**
     * Checks a member projection. The first argument is the projected instance; the second
     * lambda parameter is treated as its type and must reduce (via {@code whnf}) to a
     * {@link ClassCall}. The remaining arguments are the member's own arguments.
     *
     * @throws UnsupportedOperationException if the instance's type is not a class call
     */
    @NotNull
    private Jdg checkProjCall(@NotNull MemberDefLike member) throws Ex {
        var signature = member.signature().lift(lift);
        return makeArgs.applyChecked(signature, (args, fstTy) -> {
            assert args.length >= 1;
            if (!(whnf(fstTy) instanceof ClassCall classTy)) {
                throw new UnsupportedOperationException("report");
            }
            var fieldArgs = ImmutableArray.fill(args.length - 1, i -> args[i + 1]);
            return new Jdg.Default(MemberCall.make(classTy, args[0], member, 0, fieldArgs), signature.result(args));
        });
    }

    /** Returns a telescope over the first {@code memberCount} members of {@code def}. */
    @NotNull
    private AbstractTele ofClassMembers(@NotNull ClassDefLike def, int memberCount) {
        return new TakeMembers(def, memberCount);
    }

    /**
     * Elaborates arguments against a telescope. The continuation receives the checked arguments
     * and a term; only {@code checkProjCall} uses that term, treating it as the type of the
     * first argument.
     */
    @FunctionalInterface
    public interface Factory<Ex extends Exception>
        extends CheckedBiFunction<AbstractTele, BiFunction<Term[], Term, Jdg>, Jdg, Ex> {
    }

    /** A telescope consisting of the first {@code telescopeSize} members of {@code clazz}; every parameter is explicit. */
    record TakeMembers(@NotNull ClassDefLike clazz, int telescopeSize) implements AbstractTele
    {
        @Override
        public boolean telescopeLicit(int i) {
            return true;
        }

        @Override
        @NotNull
        public String telescopeName(int i) {
            assert i < telescopeSize;
            return clazz.members().get(i).name();
        }

        @Override
        @NotNull
        public Term telescope(int i, Seq<Term> teleArgs) {
            assert i < telescopeSize;
            return clazz.telescope(i, teleArgs);
        }

        @Override
        @NotNull
        public Term result(Seq<Term> teleArgs) {
            return clazz.result(telescopeSize);
        }

        @Override
        @NotNull
        public SeqView<String> namesView() {
            return clazz.members().sliceView(0, telescopeSize).map(AnyDef::name);
        }
    }
}

