/*
 * Decompiled with CFR 0.152.
 */
package org.aya.core.def;

import java.util.EnumMap;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import kala.collection.Map;
import kala.collection.immutable.ImmutableSeq;
import kala.control.Option;
import kala.tuple.Tuple;
import kala.tuple.Tuple2;
import org.aya.concrete.stmt.Decl;
import org.aya.core.def.Def;
import org.aya.core.def.TopLevelDef;
import org.aya.core.term.CallTerm;
import org.aya.core.term.ElimTerm;
import org.aya.core.term.FormTerm;
import org.aya.core.term.IntroTerm;
import org.aya.core.term.RefTerm;
import org.aya.core.term.Term;
import org.aya.generic.Arg;
import org.aya.generic.util.NormalizeMode;
import org.aya.ref.DefVar;
import org.aya.ref.LocalVar;
import org.aya.tyck.TyckState;
import org.jetbrains.annotations.NonNls;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/**
 * Core definition of a built-in primitive ({@code prim} declaration): the interval type {@code I},
 * its endpoints {@code left}/{@code right}, and the primitive operations {@code arcoe},
 * {@code squeezeL} and {@code invol}.
 * <p>
 * Constructing a {@code PrimDef} installs it as the {@code core} of its {@link DefVar}.
 * Definitions are normally created through {@link Factory}, which keeps at most one per {@link ID}.
 */
public final class PrimDef extends TopLevelDef {
    @NotNull
    public final DefVar<PrimDef, Decl.PrimDecl> ref;
    @NotNull
    public final ID id;

    /**
     * @param ref       the variable this definition is bound to; its {@code core} is set to {@code this}
     * @param telescope the core parameters of the primitive
     * @param result    the core result type of the primitive
     * @param name      which primitive this definition implements
     */
    public PrimDef(@NotNull DefVar<PrimDef, Decl.PrimDecl> ref, @NotNull ImmutableSeq<Term.Param> telescope, @NotNull Term result, @NotNull ID name) {
        super(telescope, result);
        this.ref = ref;
        this.id = name;
        ref.core = this;
    }

    /** Creates a nullary primitive, i.e. one with an empty telescope. */
    public PrimDef(@NotNull DefVar<PrimDef, Decl.PrimDecl> ref, @NotNull Term result, @NotNull ID name) {
        this(ref, ImmutableSeq.empty(), result, name);
    }

    /**
     * Builds a call to the interval type {@code I}.
     * Requires {@link ID#INTERVAL} to be registered in {@link Factory#INSTANCE}.
     */
    @NotNull
    public static CallTerm.Prim intervalCall() {
        return new CallTerm.Prim(Factory.INSTANCE.getOption(ID.INTERVAL).get().ref(), 0, ImmutableSeq.empty());
    }

    @Override
    public <P, R> R accept(@NotNull Def.Visitor<P, R> visitor, P p) {
        return visitor.visitPrim(this, p);
    }

    /**
     * Applies this primitive's reduction rule to {@code primCall}.
     * Dispatches on the stored {@link #id}; there is no need to re-resolve it from the name.
     */
    @NotNull
    public Term unfold(@NotNull CallTerm.Prim primCall, @Nullable TyckState state) {
        return Factory.INSTANCE.unfold(id, primCall, state);
    }

    /**
     * Returns the parameters of this primitive. Nullary primitives always report their core
     * telescope; otherwise the signature attached to the concrete declaration, when present,
     * takes precedence over the core telescope.
     */
    @Override
    @NotNull
    public ImmutableSeq<Term.Param> telescope() {
        if (telescope.isEmpty()) return telescope;
        var signature = concreteSignature();
        return signature != null ? signature.param() : telescope;
    }

    /**
     * Returns the result type of this primitive: the concrete declaration's signature result
     * when present, otherwise the core result.
     * NOTE(review): unlike {@link #telescope()} there is no short-circuit for nullary
     * primitives here — confirm this asymmetry is intended.
     */
    @Override
    @NotNull
    public Term result() {
        var signature = concreteSignature();
        return signature != null ? signature.result() : result;
    }

    /** The signature attached to the concrete declaration, or {@code null} if there is none (yet). */
    @Nullable
    private Def.Signature concreteSignature() {
        Decl.PrimDecl concrete = ref.concrete;
        return concrete == null ? null : concrete.signature;
    }

    @NotNull
    public DefVar<PrimDef, Decl.PrimDecl> ref() {
        return ref;
    }

    /** Identifies each primitive by the surface name it is declared with. */
    public enum ID {
        INTERVAL("I"),
        LEFT("left"),
        RIGHT("right"),
        ARCOE("arcoe"),
        SQUEEZE_LEFT("squeezeL"),
        INVOL("invol");

        /** The surface name of the primitive. */
        @NotNull
        @NonNls
        public final String id;

        ID(@NotNull String id) {
            this.id = id;
        }

        @Override
        public String toString() {
            return id;
        }

        /** Resolves a surface name to its primitive, or {@code null} if no primitive has that name. */
        @Nullable
        public static ID find(@NotNull String id) {
            for (var candidate : values()) {
                if (Objects.equals(candidate.id, id)) return candidate;
            }
            return null;
        }
    }

    /**
     * Registry of the primitives defined so far, at most one per {@link ID}.
     * NOTE(review): state lives in a plain {@link EnumMap} on the shared {@link #INSTANCE},
     * with no synchronization.
     */
    public static class Factory {
        @NotNull
        private final EnumMap<@NotNull ID, @NotNull PrimDef> defs = new EnumMap<>(ID.class);
        @NotNull
        public static final Factory INSTANCE = new Factory();
        /** How to build, reduce and order each primitive, keyed by its {@link ID}. */
        @NotNull
        private static final Map<@NotNull ID, @NotNull PrimSeed> SEEDS = ImmutableSeq.of(
                PrimSeed.INTERVAL, PrimSeed.LEFT, PrimSeed.RIGHT,
                PrimSeed.ARCOE, PrimSeed.SQUEEZE_LEFT, PrimSeed.INVOL)
            .map(seed -> Tuple.of(seed.name(), seed))
            .toImmutableMap();
        /** The two interval endpoints. */
        @NotNull
        public static final ImmutableSeq<ID> LEFT_RIGHT = ImmutableSeq.of(ID.LEFT, ID.RIGHT);

        private Factory() {
        }

        /**
         * Builds the core definition of {@code name} from its seed and registers it.
         * Callers must ensure {@code name} is not registered yet (checked only by an assertion).
         */
        @NotNull
        public PrimDef factory(@NotNull ID name, @NotNull DefVar<PrimDef, Decl.PrimDecl> ref) {
            assert !have(name);
            var created = SEEDS.get(name).supply(ref);
            defs.put(name, created);
            return created;
        }

        /** The registered definition of {@code name}, or an empty option if there is none. */
        @NotNull
        public Option<PrimDef> getOption(@NotNull ID name) {
            return Option.of(defs.get(name));
        }

        /** Whether {@code name} has been registered. */
        public boolean have(@NotNull ID name) {
            return defs.containsKey(name);
        }

        /** Returns the registered definition of {@code name}, creating it via {@link #factory} if absent. */
        @NotNull
        public PrimDef getOrCreate(@NotNull ID name, @NotNull DefVar<PrimDef, Decl.PrimDecl> ref) {
            return getOption(name).getOrElse(() -> factory(name, ref));
        }

        /**
         * Lists the dependencies of {@code name} that are not registered yet; an empty option
         * means no seed exists for {@code name}.
         */
        @NotNull
        public Option<ImmutableSeq<@NotNull ID>> checkDependency(@NotNull ID name) {
            return SEEDS.getOption(name).map(seed -> seed.dependency().filterNot(this::have));
        }

        /** Applies the reduction rule of primitive {@code name} to {@code primCall}. */
        @NotNull
        public Term unfold(@NotNull ID name, @NotNull CallTerm.Prim primCall, @Nullable TyckState state) {
            return SEEDS.get(name).unfold().apply(primCall, state);
        }

        /** Whether {@code core} is the registered {@code left} or {@code right} definition (by identity). */
        public boolean leftOrRight(PrimDef core) {
            for (var endpoint : LEFT_RIGHT) {
                var def = getOption(endpoint);
                if (def.isNotEmpty() && def.get() == core) return true;
            }
            return false;
        }

        /** Forgets every registered definition. */
        public void clear() {
            defs.clear();
        }
    }

    /**
     * Recipe for one primitive.
     *
     * @param name       the primitive being described
     * @param unfold     its reduction rule, applied to a call and the (optional) type-checking state
     * @param supplier   builds its core definition for a given variable
     * @param dependency primitives that must be registered before this one
     */
    record PrimSeed(@NotNull ID name,
                    @NotNull BiFunction<CallTerm.Prim, TyckState, Term> unfold,
                    @NotNull Function<DefVar<PrimDef, Decl.PrimDecl>, PrimDef> supplier,
                    @NotNull ImmutableSeq<@NotNull ID> dependency) {
        /** {@code I : Type0}; does not reduce. */
        @NotNull
        public static final PrimSeed INTERVAL = new PrimSeed(ID.INTERVAL, (prim, state) -> prim,
            ref -> new PrimDef(ref, FormTerm.Univ.ZERO, ID.INTERVAL), ImmutableSeq.empty());
        /** {@code left : I}; does not reduce. */
        @NotNull
        public static final PrimSeed LEFT = new PrimSeed(ID.LEFT, (prim, state) -> prim,
            ref -> new PrimDef(ref, PrimDef.intervalCall(), ID.LEFT), ImmutableSeq.of(ID.INTERVAL));
        /** {@code right : I}; does not reduce. */
        @NotNull
        public static final PrimSeed RIGHT = new PrimSeed(ID.RIGHT, (prim, state) -> prim,
            ref -> new PrimDef(ref, PrimDef.intervalCall(), ID.RIGHT), ImmutableSeq.of(ID.INTERVAL));
        /** {@code arcoe (A : I -> Type0) (base : A left) (i : I) : A i}. */
        @NotNull
        public static final PrimSeed ARCOE = new PrimSeed(ID.ARCOE, PrimSeed::arcoe, ref -> {
            var paramA = new LocalVar("A");
            var domainI = new Term.Param(new LocalVar("_"), PrimDef.intervalCall(), true);
            var paramI = new LocalVar("i");
            // A : I -> Type0
            var aType = new FormTerm.Pi(domainI, new FormTerm.Univ(0));
            var aRef = new RefTerm(paramA, 0);
            var left = Factory.INSTANCE.getOption(ID.LEFT).get();
            // base : A left
            var baseType = new ElimTerm.App(aRef, new Arg<Term>(nullaryCall(left), true));
            var tele = ImmutableSeq.of(
                new Term.Param(paramA, aType, true),
                new Term.Param(new LocalVar("base"), baseType, true),
                new Term.Param(paramI, PrimDef.intervalCall(), true));
            // result : A i
            var resultType = new ElimTerm.App(aRef, new Arg<Term>(new RefTerm(paramI, 0), true));
            return new PrimDef(ref, tele, resultType, ID.ARCOE);
        }, ImmutableSeq.of(ID.INTERVAL, ID.LEFT));
        /** {@code invol (i : I) : I}. */
        @NotNull
        public static final PrimSeed INVOL = new PrimSeed(ID.INVOL, PrimSeed::invol,
            ref -> new PrimDef(ref,
                ImmutableSeq.of(new Term.Param(new LocalVar("i"), PrimDef.intervalCall(), true)),
                PrimDef.intervalCall(), ID.INVOL),
            ImmutableSeq.of(ID.INTERVAL));
        /** {@code squeezeL (i j : I) : I}. */
        @NotNull
        public static final PrimSeed SQUEEZE_LEFT = new PrimSeed(ID.SQUEEZE_LEFT, PrimSeed::squeezeLeft,
            ref -> new PrimDef(ref,
                ImmutableSeq.of(
                    new Term.Param(new LocalVar("i"), PrimDef.intervalCall(), true),
                    new Term.Param(new LocalVar("j"), PrimDef.intervalCall(), true)),
                PrimDef.intervalCall(), ID.SQUEEZE_LEFT),
            ImmutableSeq.of(ID.INTERVAL));

        /** Builds the core definition of this primitive for {@code ref}. */
        @NotNull
        public PrimDef supply(@NotNull DefVar<PrimDef, Decl.PrimDecl> ref) {
            return supplier.apply(ref);
        }

        /**
         * Whether {@code term} is a call to {@code def}. An absent {@code def} (primitive not
         * declared yet) never matches, so reductions mentioning it simply do not fire.
         */
        private static boolean isCallTo(@NotNull Term term, @NotNull Option<PrimDef> def) {
            return def.isNotEmpty()
                && term instanceof CallTerm.Prim call
                && call.ref() == def.get().ref;
        }

        /** A call to the nullary primitive {@code def} (used for the interval endpoints). */
        @NotNull
        private static CallTerm.Prim nullaryCall(@NotNull PrimDef def) {
            return new CallTerm.Prim(def.ref, 0, ImmutableSeq.empty());
        }

        /**
         * Reduces {@code arcoe A base i}:
         * <ul>
         *   <li>to {@code base} when {@code i} (in WHNF) is {@code left};</li>
         *   <li>to {@code base} when {@code A} is a lambda whose normalised body does not use its parameter;</li>
         *   <li>otherwise, when {@code A} is a lambda, to the same call with the lambda body normalised;</li>
         *   <li>otherwise the call is returned unchanged.</li>
         * </ul>
         */
        @NotNull
        private static Term arcoe(@NotNull CallTerm.Prim prim, @Nullable TyckState state) {
            var args = prim.args();
            var argA = args.get(0);
            var argBase = args.get(1);
            var argI = args.get(2);
            // Normalise so that endpoints hidden behind other reductions (e.g. `invol right`) are seen.
            var i = argI.term().normalize(state, NormalizeMode.WHNF);
            if (isCallTo(i, Factory.INSTANCE.getOption(ID.LEFT))) return argBase.term();
            if (argA.term() instanceof IntroTerm.Lambda lambda) {
                var body = lambda.body().normalize(state, NormalizeMode.NF);
                if (body.findUsages(lambda.param().ref()) == 0) return argBase.term();
                return new CallTerm.Prim(prim.ref(), prim.ulift(), ImmutableSeq.of(
                    new Arg<Term>(new IntroTerm.Lambda(lambda.param(), body), true), argBase, argI));
            }
            return prim;
        }

        /**
         * Reduces {@code invol i}: swaps {@code left} and {@code right}; otherwise returns the call
         * rebuilt with its argument in WHNF. The swap only fires when both endpoints are registered.
         */
        @NotNull
        private static Term invol(@NotNull CallTerm.Prim prim, @Nullable TyckState state) {
            var arg = prim.args().get(0).term().normalize(state, NormalizeMode.WHNF);
            var left = Factory.INSTANCE.getOption(ID.LEFT);
            var right = Factory.INSTANCE.getOption(ID.RIGHT);
            if (right.isNotEmpty() && isCallTo(arg, left)) return nullaryCall(right.get());
            if (left.isNotEmpty() && isCallTo(arg, right)) return nullaryCall(left.get());
            return new CallTerm.Prim(prim.ref(), 0, ImmutableSeq.of(new Arg<Term>(arg, true)));
        }

        /**
         * Reduces {@code squeezeL i j} on either side: a {@code left} argument is returned as the
         * result, a {@code right} argument yields the other argument. Both sides are inspected
         * independently, so a stuck primitive on the left does not block a reduction on the right.
         */
        @NotNull
        private static Term squeezeLeft(@NotNull CallTerm.Prim prim, @Nullable TyckState state) {
            var lhs = prim.args().get(0).term().normalize(state, NormalizeMode.WHNF);
            var rhs = prim.args().get(1).term().normalize(state, NormalizeMode.WHNF);
            var left = Factory.INSTANCE.getOption(ID.LEFT);
            var right = Factory.INSTANCE.getOption(ID.RIGHT);
            if (isCallTo(lhs, left)) return lhs;
            if (isCallTo(lhs, right)) return rhs;
            if (isCallTo(rhs, left)) return rhs;
            if (isCallTo(rhs, right)) return lhs;
            return prim;
        }
    }
}

