/*
 * Decompiled with CFR 0.152.
 */
package org.aya.terck;

import java.lang.runtime.SwitchBootstraps;
import java.util.Objects;
import kala.collection.SeqLike;
import kala.collection.SeqView;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.mutable.MutableSet;
import kala.tuple.Tuple2;
import kala.tuple.Unit;
import kala.value.Ref;
import org.aya.core.Matching;
import org.aya.core.def.Def;
import org.aya.core.def.FnDef;
import org.aya.core.pat.Pat;
import org.aya.core.term.CallTerm;
import org.aya.core.term.ElimTerm;
import org.aya.core.term.RefTerm;
import org.aya.core.term.Term;
import org.aya.core.visitor.DefConsumer;
import org.aya.generic.Arg;
import org.aya.ref.DefVar;
import org.aya.ref.Var;
import org.aya.terck.CallGraph;
import org.aya.terck.CallMatrix;
import org.aya.terck.Relation;
import org.jetbrains.annotations.NotNull;

public record CallResolver(@NotNull FnDef caller, @NotNull MutableSet<Def> targets, @NotNull Ref<Matching> currentMatching) implements DefConsumer<CallGraph<Def, Term.Param>>
{
    public CallResolver(@NotNull FnDef fn, @NotNull MutableSet<Def> targets) {
        this(fn, targets, (Ref<Matching>)new Ref());
    }

    private void resolveCall(@NotNull CallTerm callTerm, CallGraph<Def, Term.Param> graph) {
        Var var = callTerm.ref();
        if (!(var instanceof DefVar)) {
            return;
        }
        DefVar defVar = (DefVar)var;
        Object callee = defVar.core;
        if (!this.targets.contains(callee)) {
            return;
        }
        CallMatrix<Def, Term.Param> matrix = new CallMatrix<Def, Term.Param>(callTerm, this.caller, (Def)callee, (ImmutableSeq<Term.Param>)this.caller.telescope, callee.telescope());
        this.fillMatrix(callTerm, (Def)callee, matrix);
        graph.put(matrix);
    }

    private void fillMatrix(@NotNull CallTerm callTerm, @NotNull Def callee, CallMatrix<Def, Term.Param> matrix) {
        Matching matching = (Matching)this.currentMatching.value;
        if (matching == null) {
            return;
        }
        for (Tuple2 domThing : matching.patterns().zipView((SeqLike)this.caller.telescope)) {
            for (Tuple2 codomThing : callTerm.args().zipView(callee.telescope())) {
                Relation relation = this.compare((Term)((Arg)codomThing._1).term(), (Pat)domThing._1);
                matrix.set((Term.Param)domThing._2, (Term.Param)codomThing._2, relation);
            }
        }
    }

    @NotNull
    private Relation compare(@NotNull Term lhs, @NotNull Pat rhs) {
        if (rhs instanceof Pat.Ctor) {
            Pat.Ctor ctor = (Pat.Ctor)rhs;
            if (lhs instanceof CallTerm.Con) {
                CallTerm.Con con = (CallTerm.Con)lhs;
                if (con.ref() != ctor.ref()) {
                    return Relation.Unknown;
                }
                if (ctor.params().isEmpty()) {
                    return Relation.Equal;
                }
                SeqView subCompare = con.conArgs().zipView(ctor.params()).map(sub -> this.compare((Term)((Arg)sub._1).term(), (Pat)sub._2));
                return (Relation)subCompare.max();
            }
            SeqView subCompare = ctor.params().view().map(sub -> this.compare(lhs, (Pat)sub));
            return subCompare.anyMatch(r -> r != Relation.Unknown) ? Relation.LessThan : Relation.Unknown;
        }
        if (rhs instanceof Pat.Bind) {
            Pat.Bind bind = (Pat.Bind)rhs;
            if (lhs instanceof RefTerm) {
                RefTerm ref = (RefTerm)lhs;
                return ref.var() == bind.bind() ? Relation.Equal : Relation.Unknown;
            }
            Term term = this.headOf(lhs);
            if (term instanceof RefTerm) {
                RefTerm ref = (RefTerm)term;
                return ref.var() == bind.bind() ? Relation.LessThan : Relation.Unknown;
            }
        }
        return Relation.Unknown;
    }

    @NotNull
    private Term headOf(@NotNull Term term) {
        Term term2 = term;
        Objects.requireNonNull(term2);
        Term term3 = term2;
        int n = 0;
        return switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{ElimTerm.App.class, ElimTerm.Proj.class, CallTerm.Access.class}, (Object)term3, n)) {
            case 0 -> {
                ElimTerm.App app = (ElimTerm.App)term3;
                yield this.headOf(app.of());
            }
            case 1 -> {
                ElimTerm.Proj proj = (ElimTerm.Proj)term3;
                yield this.headOf(proj.of());
            }
            case 2 -> {
                CallTerm.Access access = (CallTerm.Access)term3;
                yield this.headOf(access.of());
            }
            default -> term;
        };
    }

    @Override
    public void visitMatching(@NotNull Matching matching, CallGraph<Def, Term.Param> graph) {
        this.currentMatching.value = matching;
        DefConsumer.super.visitMatching(matching, graph);
        this.currentMatching.value = null;
    }

    @Override
    public Unit visitFnCall(@NotNull CallTerm.Fn fnCall, CallGraph<Def, Term.Param> graph) {
        this.resolveCall(fnCall, graph);
        return DefConsumer.super.visitFnCall(fnCall, graph);
    }

    @Override
    public Unit visitConCall(@NotNull CallTerm.Con conCall, CallGraph<Def, Term.Param> graph) {
        this.resolveCall(conCall, graph);
        return DefConsumer.super.visitConCall(conCall, graph);
    }

    @Override
    public Unit visitDataCall(@NotNull CallTerm.Data dataCall, CallGraph<Def, Term.Param> graph) {
        this.resolveCall(dataCall, graph);
        return DefConsumer.super.visitDataCall(dataCall, graph);
    }

    @Override
    public Unit visitStructCall(@NotNull CallTerm.Struct structCall, CallGraph<Def, Term.Param> graph) {
        this.resolveCall(structCall, graph);
        return DefConsumer.super.visitStructCall(structCall, graph);
    }

    @Override
    public Unit visitAccess(@NotNull CallTerm.Access term, CallGraph<Def, Term.Param> defCallGraph) {
        this.resolveCall(term, defCallGraph);
        return DefConsumer.super.visitAccess(term, defCallGraph);
    }

    @Override
    public Unit visitPrimCall(@NotNull CallTerm.Prim prim, CallGraph<Def, Term.Param> graph) {
        this.resolveCall(prim, graph);
        return DefConsumer.super.visitPrimCall(prim, graph);
    }
}

