/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.kernel.api.impl.schema.vector;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.assertj.core.api.AbstractCollectionAssert;
import org.assertj.core.api.AbstractDoubleAssert;
import org.assertj.core.api.AbstractFloatArrayAssert;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.assertj.core.api.ListAssert;
import org.eclipse.collections.api.factory.Lists;
import org.eclipse.collections.impl.factory.Maps;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.neo4j.configuration.GraphDatabaseInternalSettings;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Label;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.ResourceIterator;
import org.neo4j.graphdb.Result;
import org.neo4j.graphdb.Transaction;
import org.neo4j.internal.helpers.MathUtil;
import org.neo4j.kernel.api.impl.schema.vector.VectorSimilarityFunction;
import org.neo4j.kernel.api.schema.vector.VectorTestUtils;
import org.neo4j.procedure.builtin.VectorIndexProcedures;
import org.neo4j.test.RandomSupport;
import org.neo4j.test.TestDatabaseManagementServiceBuilder;
import org.neo4j.test.extension.ExtensionCallback;
import org.neo4j.test.extension.ImpermanentDbmsExtension;
import org.neo4j.test.extension.Inject;
import org.neo4j.test.extension.RandomExtension;

/**
 * Integration tests for the vector index procedures {@code db.index.vector.createNodeIndex},
 * {@code db.index.vector.queryNodes} and {@code db.create.setVectorProperty}.
 *
 * <p>The shared scenarios live in {@link VectorIndexProceduresITBase}; each {@link Nested} subclass runs them
 * against one {@link VectorSimilarityFunction}.
 */
class VectorIndexProceduresIT {
    VectorIndexProceduresIT() {
    }

    /** Shared scenarios with {@link VectorSimilarityFunction#COSINE}, plus cosine-specific l2-norm checks. */
    @Nested
    class Cosine extends VectorIndexProceduresITBase {
        Cosine() {
            super(VectorSimilarityFunction.COSINE);
        }

        @Test
        void cannotQueryWithZerol2NormVector() {
            this.createIndex();
            float[] query = new float[VECTOR_DIMENSIONALITY];
            Arrays.fill(query, 0.0f);
            try (Transaction tx = this.db.beginTx();
                    Result results = queryNodes(tx, 10, query)) {
                // The failure is expected when the result is consumed, not when execute() returns.
                Assertions.assertThatThrownBy(() -> results.resultAsString(), "zero l2-norm query vector should throw")
                        .rootCause()
                        .isInstanceOf(IllegalArgumentException.class)
                        .hasMessageContainingAll("have positive and finite l2-norm", "Provided");
            }
        }

        @Test
        void cannotQueryWithNonFinitel2NormVector() {
            this.createIndex();
            float[] query = this.randomVector();
            // Float.MAX_VALUE squared overflows the float range, so the expected l2-norm is not finite.
            query[this.random.nextInt(query.length)] = Float.MAX_VALUE;
            try (Transaction tx = this.db.beginTx();
                    Result results = queryNodes(tx, 10, query)) {
                Assertions.assertThatThrownBy(
                                () -> results.resultAsString(), "non-finite l2-norm query vector should throw")
                        .rootCause()
                        .isInstanceOf(IllegalArgumentException.class)
                        .hasMessageContainingAll("have positive and finite l2-norm", "Provided");
            }
        }
    }

    /** Shared scenarios with {@link VectorSimilarityFunction#EUCLIDEAN}. */
    @Nested
    class Euclidean extends VectorIndexProceduresITBase {
        Euclidean() {
            super(VectorSimilarityFunction.EUCLIDEAN);
        }
    }

    /**
     * Scenarios common to every similarity function.
     *
     * <p>Before each test, {@link #NUMBER_OF_NODES} nodes labelled {@code Vector} are created, each with a random
     * {@link #VECTOR_DIMENSIONALITY}-dimensional {@code float[]} in the {@code vector} property. Tests create the
     * index themselves via {@link #createIndex()}.
     */
    @ImpermanentDbmsExtension(configurationCallback = "configure")
    @ExtendWith(RandomExtension.class)
    static abstract class VectorIndexProceduresITBase {
        private static final int NUMBER_OF_NODES = 1000;
        private static final int MAX_PARTITION_SIZE = 400;
        protected static final int VECTOR_DIMENSIONALITY = 1000;
        private static final Label LABEL = Label.label("Vector");
        private static final String PROPERTY_KEY = "vector";
        private static final String INDEX_NAME = "VectorIndex";
        private final VectorSimilarityFunction similarityFunction;
        @Inject
        protected GraphDatabaseService db;
        @Inject
        protected RandomSupport random;

        VectorIndexProceduresITBase(VectorSimilarityFunction similarityFunction) {
            this.similarityFunction = similarityFunction;
        }

        /** Caps the Lucene partition size for the test DBMS. */
        @ExtensionCallback
        void configure(TestDatabaseManagementServiceBuilder builder) {
            // NOTE(review): presumably smaller than NUMBER_OF_NODES so the index spans several partitions — confirm.
            builder.setConfig(GraphDatabaseInternalSettings.lucene_max_partition_size, MAX_PARTITION_SIZE);
        }

        /** Creates {@link #NUMBER_OF_NODES} labelled nodes, each holding a random full-dimensional vector. */
        @BeforeEach
        void createData() {
            try (Transaction tx = this.db.beginTx()) {
                for (int i = 0; i < NUMBER_OF_NODES; ++i) {
                    Node node = tx.createNode(LABEL);
                    node.setProperty(PROPERTY_KEY, this.randomVector());
                }
                tx.commit();
            }
        }

        @Test
        void testCreateAndQueryIndex() {
            Assertions.assertThatCode(this::createIndex).doesNotThrowAnyException();
            int k = 10;
            float[] query = this.randomVector();
            Assertions.assertThatCode(() -> this.queryNodesAndCollect(k, query)).doesNotThrowAnyException();
        }

        /** The approximate top-k from the index must overlap the exact top-k by more than half. */
        @Test
        void testRecall() {
            this.createIndex();
            int k = 10;
            float[] query = this.randomVector();
            List<VectorIndexProcedures.Neighbor> approximateNearest = this.queryNodesAndCollect(k, query);
            List<VectorIndexProcedures.Neighbor> exactNearest = this.linearSearch(LABEL, PROPERTY_KEY, query, k);
            double recall = recall(approximateNearest, exactNearest);
            Assertions.assertThat(approximateNearest)
                    .as("approximate nearest neighbors")
                    .hasSize(k)
                    .isSorted();
            Assertions.assertThat(recall).as("recall").isGreaterThan(0.5);
        }

        /** Nodes whose vector has a different dimensionality than the index must not be returned by a query. */
        @Test
        void indexOnlyMatchingDimensions() {
            HashSet<String> nonIndexedVectors = new HashSet<>();
            try (Transaction tx = this.db.beginTx();
                    ResourceIterator<Node> nodes = tx.findNodes(LABEL)) {
                while (nodes.hasNext()) {
                    // Advance before deciding: a `continue` without next() would re-test the same node until the
                    // draw passed, re-dimensioning every node and leaving nothing indexed (a vacuous test).
                    Node node = nodes.next();
                    if (this.random.nextFloat() < 0.8f) {
                        continue;
                    }
                    node.setProperty(PROPERTY_KEY, this.randomVector(VECTOR_DIMENSIONALITY - 1));
                    nonIndexedVectors.add(node.getElementId());
                }
                tx.commit();
            }
            this.createIndex();
            float[] query = this.randomVector();
            HashSet<String> indexedVectors = new HashSet<>();
            // Ask for as many neighbours as there are nodes so every indexed node can be returned.
            try (Transaction tx = this.db.beginTx();
                    Result results = queryNodes(tx, NUMBER_OF_NODES, query)) {
                results.accept(row -> {
                    indexedVectors.add(row.getNode("node").getElementId());
                    return true; // keep visiting every row
                });
            }
            Assertions.assertThat(indexedVectors)
                    .as("vectors of different dimensions should not be indexed")
                    .doesNotContainAnyElementsOf(nonIndexedVectors);
        }

        @Test
        void cannotQueryWithWrongDimensions() {
            this.createIndex();
            float[] query = this.randomVector(VECTOR_DIMENSIONALITY - 1);
            try (Transaction tx = this.db.beginTx();
                    Result results = queryNodes(tx, 10, query)) {
                Assertions.assertThatThrownBy(
                                () -> results.resultAsString(), "incorrectly dimensioned query should throw")
                        .rootCause()
                        .isInstanceOf(IllegalArgumentException.class)
                        .hasMessageContainingAll(
                                "Index query vector has",
                                String.valueOf(VECTOR_DIMENSIONALITY - 1),
                                "dimensions",
                                "but indexed vectors have",
                                String.valueOf(VECTOR_DIMENSIONALITY));
            }
        }

        @Test
        void cannotQueryWithNonFiniteVector() {
            this.createIndex();
            float[] query = this.randomVector();
            query[this.random.nextInt(query.length)] = Float.NaN;
            try (Transaction tx = this.db.beginTx();
                    Result results = queryNodes(tx, 10, query)) {
                Assertions.assertThatThrownBy(() -> results.resultAsString(), "non-finite query vector should throw")
                        .rootCause()
                        .isInstanceOf(IllegalArgumentException.class)
                        .hasMessageContainingAll("Index query vector must contain finite values", "Provided");
            }
        }

        /** Argument source for {@link #canSetValidVector(List)}. */
        private static Iterable<List<Double>> canSetValidVector() {
            return VectorTestUtils.EUCLIDEAN_VALID_VECTORS_FROM_DOUBLE_LIST;
        }

        /** A valid vector set through the procedure is persisted as a {@code float[]} with the same values. */
        @ParameterizedTest
        @MethodSource
        void canSetValidVector(List<Double> candidate) {
            long id;
            try (Transaction tx = this.db.beginTx()) {
                Node node = tx.createNode(LABEL);
                id = node.getId();
                try (Result result = setVectorProperty(tx, node, candidate)) {
                    Assertions.assertThatCode(() -> collectNodes(result)).doesNotThrowAnyException();
                }
                tx.commit();
            }
            try (Transaction tx = this.db.beginTx()) {
                Node node = tx.getNodeById(id);
                Assertions.assertThat(node.getPropertyKeys()).contains(PROPERTY_KEY);
                Object vector = node.getProperty(PROPERTY_KEY);
                float[] expected = Lists.immutable.ofAll(candidate).collectFloat(Double::floatValue).toArray();
                Assertions.assertThat(vector)
                        .asInstanceOf(InstanceOfAssertFactories.FLOAT_ARRAY)
                        .containsExactly(expected);
            }
        }

        /** Argument source for {@link #cannotSetInvalidVector(List)}. */
        private static Iterable<List<Double>> cannotSetInvalidVector() {
            return VectorTestUtils.EUCLIDEAN_INVALID_VECTORS_FROM_DOUBLE_LIST;
        }

        /** Invalid vectors are rejected; a {@code null} vector is rejected with a null-check message. */
        @ParameterizedTest
        @MethodSource
        void cannotSetInvalidVector(List<Double> candidate) {
            try (Transaction tx = this.db.beginTx()) {
                Node node = tx.createNode(LABEL);
                AbstractThrowableAssert<?, ? extends Throwable> rootAssert = Assertions.assertThatThrownBy(
                                () -> setVectorProperty(tx, node, candidate), "invalid vectors should throw")
                        .rootCause();
                if (candidate != null) {
                    rootAssert
                            .isInstanceOf(IllegalArgumentException.class)
                            .hasMessageContainingAll("Index query vector must contain finite values", "Provided");
                } else {
                    rootAssert
                            .isInstanceOf(NullPointerException.class)
                            .hasMessageContainingAll(PROPERTY_KEY, "must not be null");
                }
            }
        }

        /** Requesting a very large k must not return more than k neighbours. */
        @Test
        void cannotReturnMoreThanMax() {
            this.createIndex();
            // NOTE(review): k = 1200; presumably chosen to exceed the procedure's maximum k — confirm.
            int k = MathUtil.ceil(12000, 10);
            float[] query = this.randomVector();
            List<VectorIndexProcedures.Neighbor> approximateNearest = this.queryNodesAndCollect(k, query);
            Assertions.assertThat(approximateNearest)
                    .as("approximate nearest neighbors")
                    .hasSizeLessThanOrEqualTo(k);
        }

        /** Sanity check of the exact-search oracle used by {@link #testRecall()}. */
        @Test
        void testLinearSearch() {
            int k = 10;
            float[] query = this.randomVector();
            List<VectorIndexProcedures.Neighbor> nearest = this.linearSearch(LABEL, PROPERTY_KEY, query, k);
            Assertions.assertThat(nearest).hasSize(k).isSorted();
        }

        /** Creates the {@code VectorIndex} node index over {@code :Vector(vector)} with this class's similarity. */
        protected void createIndex() {
            this.db.executeTransactionally(
                    "CALL db.index.vector.createNodeIndex($name, $label, $propertyKey, $dimensions, $similarity)",
                    Map.of(
                            "name", INDEX_NAME,
                            "label", LABEL.name(),
                            "propertyKey", PROPERTY_KEY,
                            "dimensions", VECTOR_DIMENSIONALITY,
                            "similarity", this.similarityFunction.name().toLowerCase(Locale.ROOT)));
        }

        /** Runs {@link #queryNodes} in its own transaction and materialises the rows as neighbours. */
        private List<VectorIndexProcedures.Neighbor> queryNodesAndCollect(int k, float[] query) {
            try (Transaction tx = this.db.beginTx();
                    Result results = queryNodes(tx, k, query)) {
                return collectNeighbors(results);
            }
        }

        /**
         * Queries the vector index for the {@code k} nodes nearest to {@code query}.
         *
         * @return rows with columns {@code score} and {@code node}, ordered by descending score
         */
        protected static Result queryNodes(Transaction tx, int k, float[] query) {
            return tx.execute(
                    "CALL db.index.vector.queryNodes($name, $k, $query) YIELD node, score\n"
                            + "RETURN score, node ORDER BY score DESC\n",
                    Map.of("name", INDEX_NAME, "k", k, "query", query));
        }

        /** Sets {@link #PROPERTY_KEY} on {@code node} to {@code vector} through {@code db.create.setVectorProperty}. */
        private static Result setVectorProperty(Transaction tx, Node node, List<Double> vector) {
            // Maps.mutable is used because it accepts a null value (Map.of does not); cannotSetInvalidVector passes
            // null candidates through here. The "vector" key binds the $vector parameter.
            return tx.execute(
                    "CALL db.create.setVectorProperty($node, $propKey, $vector)",
                    Maps.mutable.<String, Object>of("node", node, "propKey", PROPERTY_KEY, "vector", vector));
        }

        /** Converts every {@code node}/{@code score} row into a {@link VectorIndexProcedures.Neighbor}. */
        private static List<VectorIndexProcedures.Neighbor> collectNeighbors(Result results) {
            List<VectorIndexProcedures.Neighbor> neighbors = new ArrayList<>();
            results.accept(row -> {
                neighbors.add(new VectorIndexProcedures.Neighbor(
                        row.getNode("node"), row.getNumber("score").doubleValue()));
                return true; // keep visiting every row
            });
            return neighbors;
        }

        /** Converts every {@code node} row into a {@link VectorIndexProcedures.NodeRecord}. */
        private static List<VectorIndexProcedures.NodeRecord> collectNodes(Result results) {
            List<VectorIndexProcedures.NodeRecord> nodes = new ArrayList<>();
            results.accept(row -> {
                nodes.add(new VectorIndexProcedures.NodeRecord(row.getNode("node")));
                return true; // keep visiting every row
            });
            return nodes;
        }

        /**
         * Exact k-nearest-neighbour search over all {@code label} nodes, used as the ground truth for recall.
         *
         * @return up to {@code k} neighbours in descending score order
         */
        private List<VectorIndexProcedures.Neighbor> linearSearch(
                Label label, String propertyKey, float[] query, int k) {
            List<VectorIndexProcedures.Neighbor> results = new ArrayList<>();
            try (Transaction tx = this.db.beginTx();
                    ResourceIterator<Node> nodes = tx.findNodes(label)) {
                while (nodes.hasNext()) {
                    Node node = nodes.next();
                    Object value = node.getProperty(propertyKey, null);
                    if (!(value instanceof float[])) {
                        continue;
                    }
                    float[] vector = (float[]) value;
                    // Mirror the index, which skips vectors of a different dimensionality (indexOnlyMatchingDimensions).
                    if (vector.length != query.length) {
                        continue;
                    }
                    float score = this.similarityFunction.compare(query, vector);
                    int rank = getRankInCurrentResults(results, score);
                    includeInNearestK(results, k, rank, node, score);
                }
            }
            return results;
        }

        /**
         * Inserts {@code node} at {@code insertAt} when that rank is within the top {@code k}, then trims the list
         * back to at most {@code k} entries.
         */
        private static void includeInNearestK(
                List<VectorIndexProcedures.Neighbor> results, int k, int insertAt, Node node, float score) {
            // insertAt == k would be added at the end and immediately trimmed again, so it is skipped outright.
            if (insertAt < k) {
                results.add(insertAt, new VectorIndexProcedures.Neighbor(node, score));
                if (results.size() > k) {
                    results.remove(results.size() - 1);
                }
            }
        }

        /**
         * Returns the insertion index for {@code score} in {@code current}, which is kept in descending score order:
         * the position just after the last entry whose score is strictly greater.
         */
        private static int getRankInCurrentResults(List<VectorIndexProcedures.Neighbor> current, float score) {
            for (int i = current.size(); i > 0; --i) {
                if (current.get(i - 1).score() > score) {
                    return i;
                }
            }
            return 0;
        }

        /** Fraction of {@code exact} neighbours that also appear (by equality) in {@code approximate}. */
        private static double recall(
                List<VectorIndexProcedures.Neighbor> approximate, List<VectorIndexProcedures.Neighbor> exact) {
            long trueNearestFound = exact.stream().filter(approximate::contains).count();
            return (double) trueNearestFound / (double) exact.size();
        }

        /** A random vector of the indexed dimensionality. */
        protected float[] randomVector() {
            return this.randomVector(VECTOR_DIMENSIONALITY);
        }

        /** A random vector whose components are {@code random.nextFloat() - 0.5f}. */
        private float[] randomVector(int dimensions) {
            float[] vector = new float[dimensions];
            for (int i = 0; i < vector.length; ++i) {
                vector[i] = this.random.nextFloat() - 0.5f;
            }
            return vector;
        }
    }
}

