/*
 * Decompiled with CFR 0.152.
 */
package org.aya.core.pat;

import java.lang.runtime.SwitchBootstraps;
import java.util.Objects;
import kala.collection.SeqLike;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.mutable.MutableHashMap;
import kala.collection.mutable.MutableMap;
import kala.tuple.Tuple2;
import org.aya.api.ref.LocalVar;
import org.aya.api.ref.Var;
import org.aya.api.util.Arg;
import org.aya.core.def.PrimDef;
import org.aya.core.pat.Pat;
import org.aya.core.term.CallTerm;
import org.aya.core.term.IntroTerm;
import org.aya.core.term.Term;
import org.aya.core.visitor.Substituter;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

public record PatMatcher(@NotNull Substituter.TermSubst subst) {
    @Nullable
    public static Substituter.TermSubst tryBuildSubstArgs(@NotNull @NotNull ImmutableSeq<@NotNull Pat> pats, @NotNull @NotNull SeqLike<@NotNull Arg<@NotNull Term>> terms) {
        return PatMatcher.tryBuildSubstTerms(pats, (SeqLike<Term>)terms.view().map(Arg::term));
    }

    @Nullable
    public static Substituter.TermSubst tryBuildSubstTerms(@NotNull @NotNull ImmutableSeq<@NotNull Pat> pats, @NotNull @NotNull SeqLike<@NotNull Term> terms) {
        PatMatcher matchy = new PatMatcher(new Substituter.TermSubst((MutableMap<Var, Term>)new MutableHashMap()));
        try {
            for (Tuple2 pat : pats.zip(terms)) {
                matchy.match((Tuple2<Pat, Term>)pat);
            }
            return matchy.subst();
        }
        catch (Mismatch mismatch) {
            return null;
        }
    }

    private void match(@NotNull Pat pat, @NotNull Term term) throws Mismatch {
        Pat pat2 = pat;
        Objects.requireNonNull(pat2);
        Pat pat3 = pat2;
        int n = 0;
        switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{Pat.Bind.class, Pat.Absurd.class, Pat.Prim.class, Pat.Ctor.class, Pat.Tuple.class}, (Object)pat3, n)) {
            default: {
                throw new IncompatibleClassChangeError();
            }
            case 0: {
                Pat.Bind bind = (Pat.Bind)pat3;
                this.subst.addDirectly((Var)bind.as(), term);
                break;
            }
            case 1: {
                Pat.Absurd absurd = (Pat.Absurd)pat3;
                throw new Mismatch();
            }
            case 2: {
                CallTerm.Prim primCall;
                Pat.Prim prim = (Pat.Prim)pat3;
                PrimDef core = (PrimDef)prim.ref().core;
                assert (PrimDef.Factory.INSTANCE.leftOrRight(core));
                if (term instanceof CallTerm.Prim && (primCall = (CallTerm.Prim)term).ref() == prim.ref()) break;
                throw new Mismatch();
            }
            case 3: {
                Pat.Ctor ctor = (Pat.Ctor)pat3;
                if (!(term instanceof CallTerm.Con)) {
                    throw new Mismatch();
                }
                CallTerm.Con conCall = (CallTerm.Con)term;
                LocalVar as = ctor.as();
                if (as != null) {
                    this.subst.addDirectly((Var)as, conCall);
                }
                if (ctor.ref() != conCall.ref()) {
                    throw new Mismatch();
                }
                this.visitList(ctor.params(), (SeqLike<Term>)conCall.conArgs().view().map(Arg::term));
                break;
            }
            case 4: {
                Pat.Tuple tuple = (Pat.Tuple)pat3;
                if (!(term instanceof IntroTerm.Tuple)) {
                    throw new Mismatch();
                }
                IntroTerm.Tuple tup = (IntroTerm.Tuple)term;
                LocalVar as = tuple.as();
                if (as != null) {
                    this.subst.addDirectly((Var)as, tup);
                }
                this.visitList(tuple.pats(), (SeqLike<Term>)tup.items());
            }
        }
    }

    private void visitList(ImmutableSeq<Pat> lpats, SeqLike<Term> terms) throws Mismatch {
        assert (lpats.sizeEquals(terms));
        lpats.view().zip(terms).forEachChecked(this::match);
    }

    private void match(@NotNull Tuple2<Pat, Term> pp) throws Mismatch {
        this.match((Pat)pp._1, (Term)pp._2);
    }

    private static final class Mismatch
    extends Exception {
        private Mismatch() {
        }
    }
}

