/*
 * Decompiled with CFR 0.152.
 */
package io.moderne.ai.research;

import io.moderne.ai.AgentGenerativeModelClient;
import io.moderne.ai.EmbeddingModelClient;
import io.moderne.ai.RelatedModelClient;
import io.moderne.ai.table.CodeSearch;
import io.moderne.ai.table.EmbeddingPerformance;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.Generated;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Option;
import org.openrewrite.Preconditions;
import org.openrewrite.Recipe;
import org.openrewrite.ScanningRecipe;
import org.openrewrite.SourceFile;
import org.openrewrite.Tree;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.lang.Nullable;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaSourceFile;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.MethodCall;
import org.openrewrite.marker.SearchResult;

public final class FindCodeThatResembles
extends ScanningRecipe<Accumulator> {
    @Option(displayName="Resembles", description="The text, either a natural language description or a code sample, that you are looking for.", example="HTTP request with Content-Type application/json")
    private final String resembles;
    @Option(displayName="top k methods", description="Since AI based matching has a higher latency than rules based matching, we do a first pass to find the top k methods using embeddings. To narrow the scope, you can specify the top k methods as method filters.", example="1000")
    private final int k;
    private final transient CodeSearch codeSearchTable = new CodeSearch((Recipe)this);
    private final transient EmbeddingPerformance performance = new EmbeddingPerformance((Recipe)this);

    public String getDisplayName() {
        return "Find method invocations that resemble a pattern";
    }

    public String getDescription() {
        return "This recipe uses two phase AI approach to find a method invocation that resembles a search string.";
    }

    public Accumulator getInitialValue(ExecutionContext ctx) {
        return new Accumulator(this.k);
    }

    public TreeVisitor<?, ExecutionContext> getScanner(final Accumulator acc) {
        return new JavaIsoVisitor<ExecutionContext>(){

            private String extractTypeName(String fullyQualifiedTypeName) {
                return fullyQualifiedTypeName.replace("<.*>", "").substring(fullyQualifiedTypeName.lastIndexOf(46) + 1);
            }

            public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
                cu.getTypesInUse().getUsedMethods().forEach(type -> {
                    String methodSignature = this.extractTypeName(Optional.ofNullable(type.getReturnType()).map(Object::toString).orElse("")) + " " + type.getName();
                    CharSequence[] parameters = new String[type.getParameterTypes().size()];
                    for (int i = 0; i < type.getParameterTypes().size(); ++i) {
                        String typeName = this.extractTypeName(((JavaType)type.getParameterTypes().get(i)).toString());
                        String paramName = (String)type.getParameterNames().get(i);
                        parameters[i] = typeName + " " + paramName;
                    }
                    methodSignature = methodSignature + "(" + String.join((CharSequence)", ", parameters) + ")";
                    String methodPattern = Optional.ofNullable(type.getDeclaringType()).map(Object::toString).orElse("") + " " + type.getName() + "(..)";
                    acc.add(methodSignature, methodPattern, FindCodeThatResembles.this.resembles);
                });
                return super.visitCompilationUnit(cu, (Object)ctx);
            }
        };
    }

    public TreeVisitor<?, ExecutionContext> getVisitor(Accumulator acc) {
        final List<MethodMatcher> methodMatchers = acc.getMethodMatchersTopK();
        ArrayList<UsesMethod> preconditions = new ArrayList<UsesMethod>(methodMatchers.size());
        for (MethodMatcher m : methodMatchers) {
            preconditions.add(new UsesMethod(m));
        }
        return Preconditions.check((TreeVisitor)Preconditions.or((TreeVisitor[])preconditions.toArray(new TreeVisitor[0])), (TreeVisitor)new JavaIsoVisitor<ExecutionContext>(){

            public boolean isAcceptable(SourceFile sourceFile, ExecutionContext ctx) {
                return sourceFile instanceof J.CompilationUnit;
            }

            /*
             * WARNING - Removed try catching itself - possible behaviour change.
             */
            public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
                this.getCursor().putMessage("count", (Object)new AtomicInteger());
                this.getCursor().putMessage("max", (Object)new AtomicLong());
                this.getCursor().putMessage("histogram", (Object)new EmbeddingPerformance.Histogram());
                try {
                    J.CompilationUnit compilationUnit = super.visitCompilationUnit(cu, (Object)ctx);
                    return compilationUnit;
                }
                finally {
                    if (((AtomicInteger)this.getCursor().getMessage("count", (Object)new AtomicInteger())).get() > 0) {
                        Duration max = Duration.ofNanos(Objects.requireNonNull((AtomicLong)this.getCursor().getMessage("max")).get());
                        FindCodeThatResembles.this.performance.insertRow(ctx, new EmbeddingPerformance.Row(cu.getSourcePath().toString(), Objects.requireNonNull((AtomicInteger)this.getCursor().getMessage("count")).get(), Objects.requireNonNull((EmbeddingPerformance.Histogram)this.getCursor().getMessage("histogram")).getBuckets(), max));
                    }
                }
            }

            public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
                boolean result;
                boolean matches = false;
                for (Object methodMatcher : methodMatchers) {
                    if (!methodMatcher.matches((MethodCall)method)) continue;
                    matches = true;
                    break;
                }
                if (!matches) {
                    return super.visitMethodInvocation(method, (Object)ctx);
                }
                RelatedModelClient.Relatedness related = RelatedModelClient.getInstance().getRelatedness(FindCodeThatResembles.this.resembles, method.printTrimmed(this.getCursor()));
                for (Duration timing : related.getEmbeddingTimings()) {
                    Objects.requireNonNull((AtomicInteger)this.getCursor().getNearestMessage("count")).incrementAndGet();
                    Objects.requireNonNull((EmbeddingPerformance.Histogram)this.getCursor().getNearestMessage("histogram")).add(timing);
                    AtomicLong max = (AtomicLong)this.getCursor().getNearestMessage("max");
                    if (Objects.requireNonNull(max).get() >= timing.toNanos()) continue;
                    max.set(timing.toNanos());
                }
                int resultEmbeddingModels = related.isRelated();
                boolean calledGenerativeModel = false;
                if (resultEmbeddingModels == 0) {
                    result = AgentGenerativeModelClient.getInstance().isRelated(FindCodeThatResembles.this.resembles, method.printTrimmed(this.getCursor()), 0.5932);
                    calledGenerativeModel = true;
                } else {
                    result = resultEmbeddingModels == 1;
                }
                JavaSourceFile javaSourceFile = (JavaSourceFile)this.getCursor().firstEnclosing(JavaSourceFile.class);
                String source = javaSourceFile.getSourcePath().toString();
                if (result || calledGenerativeModel) {
                    FindCodeThatResembles.this.codeSearchTable.insertRow(ctx, new CodeSearch.Row(source, method.printTrimmed(this.getCursor()), FindCodeThatResembles.this.resembles, resultEmbeddingModels, calledGenerativeModel ? (result ? 1 : -1) : 0));
                }
                return result ? (J.MethodInvocation)SearchResult.found((Tree)method) : super.visitMethodInvocation(method, (Object)ctx);
            }
        });
    }

    @Generated
    public FindCodeThatResembles(String resembles, int k) {
        this.resembles = resembles;
        this.k = k;
    }

    @Generated
    public String getResembles() {
        return this.resembles;
    }

    @Generated
    public int getK() {
        return this.k;
    }

    @Generated
    public CodeSearch getCodeSearchTable() {
        return this.codeSearchTable;
    }

    @Generated
    public EmbeddingPerformance getPerformance() {
        return this.performance;
    }

    @Generated
    public String toString() {
        return "FindCodeThatResembles(resembles=" + this.getResembles() + ", k=" + this.getK() + ", codeSearchTable=" + (Object)((Object)this.getCodeSearchTable()) + ", performance=" + (Object)((Object)this.getPerformance()) + ")";
    }

    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof FindCodeThatResembles)) {
            return false;
        }
        FindCodeThatResembles other = (FindCodeThatResembles)((Object)o);
        if (!other.canEqual((Object)this)) {
            return false;
        }
        if (this.getK() != other.getK()) {
            return false;
        }
        String this$resembles = this.getResembles();
        String other$resembles = other.getResembles();
        return !(this$resembles == null ? other$resembles != null : !this$resembles.equals(other$resembles));
    }

    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof FindCodeThatResembles;
    }

    @Generated
    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        result = result * 59 + this.getK();
        String $resembles = this.getResembles();
        result = result * 59 + ($resembles == null ? 43 : $resembles.hashCode());
        return result;
    }

    public static final class Accumulator {
        private final int k;
        private final PriorityQueue<MethodSignatureWithDistance> methodSignaturesQueue = new PriorityQueue<MethodSignatureWithDistance>(Comparator.comparingDouble(MethodSignatureWithDistance::getDistance));
        private final EmbeddingModelClient embeddingModelClient = EmbeddingModelClient.getInstance();
        @Nullable
        private List<MethodMatcher> topMethodPatterns;

        public void add(String methodSignature, String methodPattern, String resembles) {
            for (MethodSignatureWithDistance entry : this.methodSignaturesQueue) {
                if (!entry.methodPattern.equals(methodPattern)) continue;
                return;
            }
            MethodSignatureWithDistance methodSignatureWithDistance = new MethodSignatureWithDistance(methodSignature, methodPattern, (float)this.embeddingModelClient.getDistance(resembles, methodSignature));
            this.methodSignaturesQueue.add(methodSignatureWithDistance);
        }

        public List<MethodMatcher> getMethodMatchersTopK() {
            if (this.topMethodPatterns != null) {
                return this.topMethodPatterns;
            }
            this.topMethodPatterns = new ArrayList<MethodMatcher>(this.k);
            for (int i = 0; i < this.k && !this.methodSignaturesQueue.isEmpty(); ++i) {
                String inputString = this.methodSignaturesQueue.poll().getMethodPattern();
                if (!inputString.contains("<constructor>")) {
                    inputString = inputString.replaceAll("<[^>]*>", "");
                }
                this.topMethodPatterns.add(new MethodMatcher(inputString, true));
            }
            return this.topMethodPatterns;
        }

        @Generated
        public int getK() {
            return this.k;
        }

        @Generated
        public PriorityQueue<MethodSignatureWithDistance> getMethodSignaturesQueue() {
            return this.methodSignaturesQueue;
        }

        @Generated
        public EmbeddingModelClient getEmbeddingModelClient() {
            return this.embeddingModelClient;
        }

        @Generated
        public List<MethodMatcher> getTopMethodPatterns() {
            return this.topMethodPatterns;
        }

        @Generated
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Accumulator)) {
                return false;
            }
            Accumulator other = (Accumulator)o;
            if (this.getK() != other.getK()) {
                return false;
            }
            PriorityQueue<MethodSignatureWithDistance> this$methodSignaturesQueue = this.getMethodSignaturesQueue();
            PriorityQueue<MethodSignatureWithDistance> other$methodSignaturesQueue = other.getMethodSignaturesQueue();
            if (this$methodSignaturesQueue == null ? other$methodSignaturesQueue != null : !this$methodSignaturesQueue.equals(other$methodSignaturesQueue)) {
                return false;
            }
            EmbeddingModelClient this$embeddingModelClient = this.getEmbeddingModelClient();
            EmbeddingModelClient other$embeddingModelClient = other.getEmbeddingModelClient();
            if (this$embeddingModelClient == null ? other$embeddingModelClient != null : !this$embeddingModelClient.equals(other$embeddingModelClient)) {
                return false;
            }
            List<MethodMatcher> this$topMethodPatterns = this.getTopMethodPatterns();
            List<MethodMatcher> other$topMethodPatterns = other.getTopMethodPatterns();
            return !(this$topMethodPatterns == null ? other$topMethodPatterns != null : !((Object)this$topMethodPatterns).equals(other$topMethodPatterns));
        }

        @Generated
        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            result = result * 59 + this.getK();
            PriorityQueue<MethodSignatureWithDistance> $methodSignaturesQueue = this.getMethodSignaturesQueue();
            result = result * 59 + ($methodSignaturesQueue == null ? 43 : $methodSignaturesQueue.hashCode());
            EmbeddingModelClient $embeddingModelClient = this.getEmbeddingModelClient();
            result = result * 59 + ($embeddingModelClient == null ? 43 : $embeddingModelClient.hashCode());
            List<MethodMatcher> $topMethodPatterns = this.getTopMethodPatterns();
            result = result * 59 + ($topMethodPatterns == null ? 43 : ((Object)$topMethodPatterns).hashCode());
            return result;
        }

        @Generated
        public String toString() {
            return "FindCodeThatResembles.Accumulator(k=" + this.getK() + ", methodSignaturesQueue=" + this.getMethodSignaturesQueue() + ", embeddingModelClient=" + this.getEmbeddingModelClient() + ", topMethodPatterns=" + this.getTopMethodPatterns() + ")";
        }

        @Generated
        public Accumulator(int k) {
            this.k = k;
        }
    }

    private static final class MethodSignatureWithDistance {
        private final String methodSignature;
        private final String methodPattern;
        private final double distance;

        @Generated
        public MethodSignatureWithDistance(String methodSignature, String methodPattern, double distance) {
            this.methodSignature = methodSignature;
            this.methodPattern = methodPattern;
            this.distance = distance;
        }

        @Generated
        public String getMethodSignature() {
            return this.methodSignature;
        }

        @Generated
        public String getMethodPattern() {
            return this.methodPattern;
        }

        @Generated
        public double getDistance() {
            return this.distance;
        }

        @Generated
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof MethodSignatureWithDistance)) {
                return false;
            }
            MethodSignatureWithDistance other = (MethodSignatureWithDistance)o;
            if (Double.compare(this.getDistance(), other.getDistance()) != 0) {
                return false;
            }
            String this$methodSignature = this.getMethodSignature();
            String other$methodSignature = other.getMethodSignature();
            if (this$methodSignature == null ? other$methodSignature != null : !this$methodSignature.equals(other$methodSignature)) {
                return false;
            }
            String this$methodPattern = this.getMethodPattern();
            String other$methodPattern = other.getMethodPattern();
            return !(this$methodPattern == null ? other$methodPattern != null : !this$methodPattern.equals(other$methodPattern));
        }

        @Generated
        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            long $distance = Double.doubleToLongBits(this.getDistance());
            result = result * 59 + (int)($distance >>> 32 ^ $distance);
            String $methodSignature = this.getMethodSignature();
            result = result * 59 + ($methodSignature == null ? 43 : $methodSignature.hashCode());
            String $methodPattern = this.getMethodPattern();
            result = result * 59 + ($methodPattern == null ? 43 : $methodPattern.hashCode());
            return result;
        }

        @Generated
        public String toString() {
            return "FindCodeThatResembles.MethodSignatureWithDistance(methodSignature=" + this.getMethodSignature() + ", methodPattern=" + this.getMethodPattern() + ", distance=" + this.getDistance() + ")";
        }
    }
}

