/*
 * Decompiled with CFR 0.152.
 */
package org.noear.solon.ai.rag.util;

import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.noear.solon.ai.rag.Document;
import org.noear.solon.ai.rag.util.QueryCondition;

/**
 * Static helpers for scoring {@link Document}s by cosine similarity against a query embedding,
 * and for post-filtering scored results (similarity threshold, descending-score sort, limit).
 *
 * <p>The class declares no fields; every method is static.
 */
public final class SimilarityUtil {
    /**
     * Re-filters scored documents using the default limit ({@code 4}) and the default
     * similarity threshold ({@code 0.4}).
     *
     * @param docs scored documents; the stream is consumed
     * @return at most 4 documents whose score is {@code >= 0.4}, highest score first
     */
    public static List<Document> refilter(Stream<Document> docs) {
        return SimilarityUtil.refilter(docs, 4);
    }

    /**
     * Re-filters scored documents using the given limit and the default similarity
     * threshold ({@code 0.4}).
     *
     * @param docs  scored documents; the stream is consumed
     * @param limit maximum number of documents to return
     * @return at most {@code limit} documents whose score is {@code >= 0.4}, highest score first
     * @throws IllegalArgumentException if {@code limit} is negative (from {@link Stream#limit})
     */
    public static List<Document> refilter(Stream<Document> docs, int limit) {
        return SimilarityUtil.refilter(docs, limit, 0.4);
    }

    /**
     * Keeps documents whose score reaches {@code similarityThreshold}, sorts them by score in
     * descending order, and truncates the result to {@code limit} entries.
     *
     * @param docs                scored documents; the stream is consumed
     * @param limit               maximum number of documents to return
     * @param similarityThreshold minimum (inclusive) score a document must have to be kept
     * @return the filtered, sorted and truncated documents
     * @throws IllegalArgumentException if {@code limit} is negative (from {@link Stream#limit})
     */
    public static List<Document> refilter(Stream<Document> docs, int limit, double similarityThreshold) {
        return docs.filter(doc -> SimilarityUtil.similarityCheck(doc, similarityThreshold))
                .sorted(Comparator.comparing(Document::getScore).reversed())
                .limit(limit)
                .collect(Collectors.toList());
    }

    /**
     * Re-filters scored documents according to a {@link QueryCondition}.
     *
     * <p>The limit and similarity threshold always come from {@code condition}. The condition's
     * own document filter ({@code condition.doFilter}) is applied only when
     * {@code condition.isDisableRefilter()} is {@code false}.
     *
     * @param docs      scored documents; the stream is consumed
     * @param condition supplies the limit, threshold and optional per-document filter
     * @return the filtered, sorted and truncated documents
     * @throws IOException NOTE(review): nothing visible in this body appears to throw it; the
     *                     clause is kept so existing callers that catch it keep compiling
     */
    public static List<Document> refilter(Stream<Document> docs, QueryCondition condition) throws IOException {
        if (condition.isDisableRefilter()) {
            return SimilarityUtil.refilter(docs, condition.getLimit(), condition.getSimilarityThreshold());
        }
        return SimilarityUtil.refilter(docs.filter(condition::doFilter), condition.getLimit(), condition.getSimilarityThreshold());
    }

    /**
     * Computes the cosine similarity between {@code queryEmbed} and the document's embedding and
     * passes it to {@code doc.score(...)}, returning whatever that call returns.
     * NOTE(review): presumably {@code Document.score} stores the score on {@code doc} itself
     * and returns it — confirm against {@code Document}.
     *
     * @param doc        document whose embedding is compared
     * @param queryEmbed query embedding
     * @return the result of {@code doc.score(similarity)}
     * @throws IllegalArgumentException if either embedding is {@code null}, the lengths differ,
     *                                  or either embedding has zero norm
     */
    public static Document score(Document doc, float[] queryEmbed) {
        return doc.score(SimilarityUtil.cosineSimilarity(queryEmbed, doc.getEmbedding()));
    }

    /**
     * Builds a new {@link Document} from {@code doc}'s id, content and metadata, with its score
     * set to the cosine similarity against {@code queryEmbed}. The original document is not
     * passed to any mutating call here. The metadata object is passed exactly as returned by
     * {@code getMetadata()} (no copy is made in this method), and the embedding is not passed
     * to the new document's constructor.
     *
     * @param doc        source document
     * @param queryEmbed query embedding
     * @return a new scored document
     * @throws IllegalArgumentException if either embedding is {@code null}, the lengths differ,
     *                                  or either embedding has zero norm
     */
    public static Document copyAndScore(Document doc, float[] queryEmbed) {
        return new Document(doc.getId(), doc.getContent(), doc.getMetadata(), SimilarityUtil.cosineSimilarity(queryEmbed, doc.getEmbedding()));
    }

    /**
     * Returns whether the document's score is at least {@code similarityThreshold} (inclusive).
     *
     * @param doc                 scored document
     * @param similarityThreshold minimum acceptable score
     * @return {@code true} if {@code doc.getScore() >= similarityThreshold}
     */
    public static boolean similarityCheck(Document doc, double similarityThreshold) {
        return doc.getScore() >= similarityThreshold;
    }

    /**
     * Cosine similarity of two equal-length vectors: {@code a·b / (|a| * |b|)}.
     *
     * <p>All sums are accumulated in {@code double}. Accumulating in {@code float} (as before)
     * loses precision on long vectors. It also lets tiny squared components underflow to zero,
     * rejecting a valid vector as "zero norm". Large ones can overflow to infinity.
     *
     * @throws IllegalArgumentException if either vector is {@code null}, their lengths differ,
     *                                  or either vector has zero norm (including empty vectors)
     */
    private static double cosineSimilarity(float[] embedA, float[] embedB) {
        if (embedA == null || embedB == null) {
            // IllegalArgumentException (a RuntimeException subclass) for consistency with the
            // other precondition failures below.
            throw new IllegalArgumentException("Embed must not be null");
        }
        if (embedA.length != embedB.length) {
            throw new IllegalArgumentException("Embed length must be equal");
        }

        // Single pass: dot product plus both squared norms, all in double precision.
        double dot = 0.0;
        double squaredNormA = 0.0;
        double squaredNormB = 0.0;
        for (int i = 0; i < embedA.length; ++i) {
            double a = embedA[i];
            double b = embedB[i];
            dot += a * b;
            squaredNormA += a * a;
            squaredNormB += b * b;
        }

        if (squaredNormA == 0.0 || squaredNormB == 0.0) {
            throw new IllegalArgumentException("Embed cannot be zero norm");
        }
        // Take the square roots separately rather than sqrt(product) so the product of the two
        // squared norms cannot overflow.
        return dot / (Math.sqrt(squaredNormA) * Math.sqrt(squaredNormB));
    }
}

