/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.kernel.api.index;

import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.ThrowableAssert;
import org.eclipse.collections.api.block.function.Function;
import org.eclipse.collections.api.block.predicate.Predicate;
import org.eclipse.collections.api.list.ImmutableList;
import org.eclipse.collections.api.partition.list.PartitionImmutableList;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.condition.EnabledIf;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.neo4j.exceptions.KernelException;
import org.neo4j.graphdb.Label;
import org.neo4j.graphdb.RelationshipType;
import org.neo4j.graphdb.Transaction;
import org.neo4j.graphdb.schema.IndexCreator;
import org.neo4j.graphdb.schema.IndexDefinition;
import org.neo4j.graphdb.schema.IndexSetting;
import org.neo4j.graphdb.schema.IndexSettingUtil;
import org.neo4j.internal.schema.IndexConfig;
import org.neo4j.internal.schema.IndexPrototype;
import org.neo4j.internal.schema.IndexType;
import org.neo4j.internal.schema.SchemaDescriptor;
import org.neo4j.internal.schema.SchemaDescriptors;
import org.neo4j.kernel.KernelVersion;
import org.neo4j.kernel.KernelVersionProvider;
import org.neo4j.kernel.api.KernelTransaction;
import org.neo4j.kernel.api.impl.schema.vector.VectorIndexVersion;
import org.neo4j.kernel.api.vector.VectorSimilarityFunction;
import org.neo4j.kernel.impl.coreapi.InternalTransaction;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.test.Tokens;
import org.neo4j.test.extension.ImpermanentDbmsExtension;
import org.neo4j.test.extension.Inject;

public class VectorIndexCreationTest {

    /**
     * Runs the shared vector index creation tests against an index on relationships of type
     * {@code VECTOR}.
     */
    @Nested
    class RelIndex extends VectorIndexCreationTestBase {
        private static final RelationshipType REL_TYPE = Tokens.Factories.RELATIONSHIP_TYPE.fromName("VECTOR");

        /** Token id of {@link #REL_TYPE}; resolved in {@link #setup(KernelTransaction)}. */
        private int relTypeTokenId;

        RelIndex() {
            // Known index versions requiring at least this kernel version are treated as valid here.
            super(KernelVersion.VERSION_VECTOR_2_INTRODUCED);
        }

        /** Resolves the token id of {@link #REL_TYPE} through the given kernel transaction. */
        @Override
        void setup(KernelTransaction ktx) throws KernelException {
            relTypeTokenId = Tokens.Factories.RELATIONSHIP_TYPE.getId(ktx, REL_TYPE);
        }

        /** Builds a relationship-type schema over {@link #REL_TYPE} and the given property keys. */
        @Override
        protected SchemaDescriptor schemaDescriptor(int... propKeyIds) {
            return SchemaDescriptors.forRelType(relTypeTokenId, propKeyIds);
        }

        /** Starts a core-API index definition on {@link #REL_TYPE}. */
        @Override
        protected IndexCreator indexCreator(Transaction tx) {
            return tx.schema().indexFor(REL_TYPE);
        }
    }

    /**
     * Runs the shared vector index creation tests against an index on nodes with label
     * {@code Vector}.
     */
    @Nested
    class NodeIndex extends VectorIndexCreationTestBase {
        private static final Label LABEL = Tokens.Factories.LABEL.fromName("Vector");

        /** Token id of {@link #LABEL}; resolved in {@link #setup(KernelTransaction)}. */
        private int labelTokenId;

        NodeIndex() {
            // Known index versions requiring at least this kernel version are treated as valid here.
            super(KernelVersion.VERSION_NODE_VECTOR_INDEX_INTRODUCED);
        }

        /** Resolves the token id of {@link #LABEL} through the given kernel transaction. */
        @Override
        void setup(KernelTransaction ktx) throws KernelException {
            labelTokenId = Tokens.Factories.LABEL.getId(ktx, LABEL);
        }

        /** Builds a label schema over {@link #LABEL} and the given property keys. */
        @Override
        protected SchemaDescriptor schemaDescriptor(int... propKeyIds) {
            return SchemaDescriptors.forLabel(labelTokenId, propKeyIds);
        }

        /** Starts a core-API index definition on {@link #LABEL}. */
        @Override
        protected IndexCreator indexCreator(Transaction tx) {
            return tx.schema().indexFor(LABEL);
        }
    }

    /**
     * Shared vector index creation tests, run once per schema kind by the {@code @Nested} subclasses.
     * <p>
     * Each subclass passes a kernel version to the constructor; {@link VectorIndexVersion#KNOWN_VERSIONS}
     * is partitioned on it into index versions for which creation must succeed and versions for which
     * it must be rejected. Tests come in two flavours:
     * <ul>
     *     <li>kernel API: an {@link IndexPrototype} with an explicit index provider, parameterized over
     *     the index versions;</li>
     *     <li>core API ({@code ...CoreAPI}): an {@link IndexCreator}, with expectations derived from
     *     the latest supported index version. These are only enabled when that version is one of the
     *     valid ones (see {@link #isLatestValid()}).</li>
     * </ul>
     * The {@code PER_CLASS} lifecycle allows {@code @BeforeAll}, {@code @MethodSource} factories and
     * {@code @EnabledIf} conditions to be instance methods.
     */
    @ImpermanentDbmsExtension
    @TestInstance(TestInstance.Lifecycle.PER_CLASS)
    abstract static class VectorIndexCreationTestBase {
        /**
         * Two generated property key names. Element 0 is used by the kernel-API tests, element 1 by the
         * core-API tests, and both together by the composite-key tests.
         */
        protected static final List<String> PROP_KEYS =
                new Tokens.Suppliers.PropertyKey("vector", Tokens.Suppliers.Suffixes.incrementing()).get(2);

        /** Known index versions whose minimum required kernel version is at least the introducing one. */
        private final ImmutableList<VectorIndexVersion> validVersions;

        /** All other known index versions; creating an index with these must be rejected. */
        private final ImmutableList<VectorIndexVersion> invalidVersions;

        @Inject
        private GraphDatabaseAPI db;

        /** Latest index version supported by the database's kernel version; set in {@link #setup()}. */
        private VectorIndexVersion latestSupportedVersion;

        /** Token ids of {@link #PROP_KEYS}, in the same order; set in {@link #setup()}. */
        protected int[] propKeyIds;

        /**
         * @param introducedKernelVersion kernel version from which vector indexes on this schema kind are
         *                                expected to work; used to split the known index versions into
         *                                valid and invalid ones
         */
        VectorIndexCreationTestBase(KernelVersion introducedKernelVersion) {
            PartitionImmutableList<VectorIndexVersion> partitioned = VectorIndexVersion.KNOWN_VERSIONS.partition(
                    indexVersion -> indexVersion.minimumRequiredKernelVersion().isAtLeast(introducedKernelVersion));
            this.validVersions = partitioned.getSelected();
            this.invalidVersions = partitioned.getRejected();
        }

        /**
         * Determines the latest supported index version for the running kernel, then resolves the
         * property key tokens and the subclass-specific tokens in one committed transaction.
         */
        @BeforeAll
        void setup() throws Exception {
            KernelVersion kernelVersion = db.getDependencyResolver()
                    .resolveDependency(KernelVersionProvider.class)
                    .kernelVersion();
            latestSupportedVersion = VectorIndexVersion.latestSupportedVersion(kernelVersion);

            try (Transaction tx = db.beginTx()) {
                KernelTransaction ktx = ((InternalTransaction) tx).kernelTransaction();
                propKeyIds = Tokens.Factories.PROPERTY_KEY.getIds(ktx, PROP_KEYS);
                setup(ktx);
                tx.commit();
            }
        }

        /** Resolves the schema-kind-specific tokens (label or relationship type) for the subclass. */
        abstract void setup(KernelTransaction ktx) throws KernelException;

        /** Drops every index so that each test starts from an empty schema. */
        @BeforeEach
        void dropAllIndexes() {
            try (Transaction tx = db.beginTx()) {
                tx.schema().getIndexes().forEach(IndexDefinition::drop);
                tx.commit();
            }
        }

        @ParameterizedTest
        @MethodSource("validVersions")
        @EnabledIf("hasValidVersions")
        void shouldAcceptTestDefaults(VectorIndexVersion version) {
            assertDoesNotThrow(() -> createVectorIndex(version, defaultConfig(), propKeyIds[0]));
        }

        @Test
        @EnabledIf("isLatestValid")
        void shouldAcceptTestDefaultsCoreAPI() {
            assertDoesNotThrow(() -> createVectorIndex(defaultSettings(), PROP_KEYS.get(1)));
        }

        @ParameterizedTest
        @MethodSource("invalidVersions")
        @EnabledIf("hasInvalidVersions")
        void shouldRejectVectorIndexOnUnsupportedVersions(VectorIndexVersion version) {
            assertUnsupportedIndex(() -> createVectorIndex(version, defaultConfig(), propKeyIds[0]));
        }

        @ParameterizedTest
        @MethodSource("validVersions")
        @EnabledIf("hasValidVersions")
        void shouldRejectCompositeKeys(VectorIndexVersion version) {
            assertUnsupportedComposite(() -> createVectorIndex(version, defaultConfig(), propKeyIds));
        }

        @Test
        @EnabledIf("isLatestValid")
        void shouldRejectCompositeKeysCoreAPI() {
            assertUnsupportedComposite(() -> createVectorIndex(defaultSettings(), PROP_KEYS));
        }

        @ParameterizedTest
        @MethodSource
        @EnabledIf("hasValidVersions")
        void shouldAcceptValidDimensions(VectorIndexVersion version, int dimensions) {
            IndexConfig config = defaultConfigWith(IndexSetting.vector_Dimensions(), dimensions);
            assertDoesNotThrow(() -> createVectorIndex(version, config, propKeyIds[0]));
        }

        /** Every valid version paired with each of its valid dimension counts. */
        Stream<Arguments> shouldAcceptValidDimensions() {
            return validVersions().flatMap(version ->
                    validDimensions(version).mapToObj(dimension -> Arguments.of(version, dimension)));
        }

        @ParameterizedTest
        @MethodSource
        @EnabledIf("isLatestValid")
        void shouldAcceptValidDimensionsCoreAPI(int dimensions) {
            Map<IndexSetting, Object> settings = defaultSettingsWith(IndexSetting.vector_Dimensions(), dimensions);
            assertDoesNotThrow(() -> createVectorIndex(settings, PROP_KEYS.get(1)));
        }

        IntStream shouldAcceptValidDimensionsCoreAPI() {
            return validDimensions(latestSupportedVersion);
        }

        /**
         * A spread of dimension counts no larger than {@code version.maxDimensions()}, always including
         * that maximum itself.
         */
        static IntStream validDimensions(VectorIndexVersion version) {
            int max = version.maxDimensions();
            return IntStream.of(1, 256, 512, 768, 1024, 1408, 1536, 2048, 3072, 4096, max)
                    .filter(d -> d <= max)
                    // max may coincide with a listed size; avoid running that case twice
                    .distinct();
        }

        @ParameterizedTest
        @MethodSource
        @EnabledIf("hasValidVersions")
        void shouldRejectIllegalDimensions(VectorIndexVersion version, int dimensions) {
            IndexConfig config = defaultConfigWith(IndexSetting.vector_Dimensions(), dimensions);
            assertIllegalDimensions(() -> createVectorIndex(version, config, propKeyIds[0]));
        }

        /** Every valid version paired with each non-positive dimension count. */
        Stream<Arguments> shouldRejectIllegalDimensions() {
            return validVersions().flatMap(version ->
                    IntStream.of(-1, 0).mapToObj(dimension -> Arguments.of(version, dimension)));
        }

        @ParameterizedTest
        @ValueSource(ints = {-1, 0})
        @EnabledIf("isLatestValid")
        void shouldRejectIllegalDimensionsCoreAPI(int dimensions) {
            Map<IndexSetting, Object> settings = defaultSettingsWith(IndexSetting.vector_Dimensions(), dimensions);
            assertIllegalDimensions(() -> createVectorIndex(settings, PROP_KEYS.get(1)));
        }

        @ParameterizedTest
        @MethodSource("validVersions")
        @EnabledIf("hasValidVersions")
        void shouldRejectUnsupportedDimensions(VectorIndexVersion version) {
            int dimensions = version.maxDimensions() + 1;
            IndexConfig config = defaultConfigWith(IndexSetting.vector_Dimensions(), dimensions);
            assertUnsupportedDimensions(() -> createVectorIndex(version, config, propKeyIds[0]));
        }

        @Test
        @EnabledIf("isLatestValid")
        void shouldRejectUnsupportedDimensionsCoreAPI() {
            int dimensions = latestSupportedVersion.maxDimensions() + 1;
            Map<IndexSetting, Object> settings = defaultSettingsWith(IndexSetting.vector_Dimensions(), dimensions);
            assertUnsupportedDimensions(() -> createVectorIndex(settings, PROP_KEYS.get(1)));
        }

        @ParameterizedTest
        @MethodSource
        @EnabledIf("hasValidVersions")
        void shouldAcceptValidSimilarityFunction(VectorIndexVersion version, VectorSimilarityFunction similarityFunction) {
            String similarityFunctionName = similarityFunction.name();
            IndexConfig config = defaultConfigWith(IndexSetting.vector_Similarity_Function(), similarityFunctionName);
            assertDoesNotThrow(() -> createVectorIndex(version, config, propKeyIds[0]));
        }

        /** Every valid version paired with each similarity function it supports. */
        private Stream<Arguments> shouldAcceptValidSimilarityFunction() {
            return validVersions().flatMap(version -> version.supportedSimilarityFunctions()
                    .asLazy()
                    .collect(similarityFunction -> Arguments.of(version, similarityFunction))
                    .toList()
                    .stream());
        }

        // Expectations come from latestSupportedVersion, so gate on it like the other core-API tests.
        @ParameterizedTest
        @MethodSource
        @EnabledIf("isLatestValid")
        void shouldAcceptValidSimilarityFunctionCoreAPI(VectorSimilarityFunction similarityFunction) {
            String similarityFunctionName = similarityFunction.name();
            Map<IndexSetting, Object> settings =
                    defaultSettingsWith(IndexSetting.vector_Similarity_Function(), similarityFunctionName);
            assertDoesNotThrow(() -> createVectorIndex(settings, PROP_KEYS.get(1)));
        }

        private Iterable<VectorSimilarityFunction> shouldAcceptValidSimilarityFunctionCoreAPI() {
            return latestSupportedVersion.supportedSimilarityFunctions();
        }

        @ParameterizedTest
        @MethodSource("validVersions")
        @EnabledIf("hasValidVersions")
        void shouldRejectIllegalSimilarityFunction(VectorIndexVersion version) {
            String similarityFunctionName = "ClearlyThisIsNotASimilarityFunction";
            IndexConfig config = defaultConfigWith(IndexSetting.vector_Similarity_Function(), similarityFunctionName);
            assertIllegalSimilarityFunction(version, () -> createVectorIndex(version, config, propKeyIds[0]));
        }

        // Expectations come from latestSupportedVersion, so gate on it like the other core-API tests.
        @Test
        @EnabledIf("isLatestValid")
        void shouldRejectIllegalSimilarityFunctionCoreAPI() {
            String similarityFunctionName = "ClearlyThisIsNotASimilarityFunction";
            Map<IndexSetting, Object> settings =
                    defaultSettingsWith(IndexSetting.vector_Similarity_Function(), similarityFunctionName);
            assertIllegalSimilarityFunction(latestSupportedVersion, () -> createVectorIndex(settings, PROP_KEYS.get(1)));
        }

        /** Whether the latest supported index version is valid for this schema kind (gates core-API tests). */
        private boolean isLatestValid() {
            return validVersions.contains(latestSupportedVersion);
        }

        private boolean hasValidVersions() {
            return !validVersions.isEmpty();
        }

        private Stream<VectorIndexVersion> validVersions() {
            return validVersions.stream();
        }

        private boolean hasInvalidVersions() {
            return !invalidVersions.isEmpty();
        }

        private Stream<VectorIndexVersion> invalidVersions() {
            return invalidVersions.stream();
        }

        /**
         * Creates a vector index through the kernel schema-write API with an explicit index provider,
         * and commits.
         *
         * @param version    index version whose provider descriptor is used
         * @param config     index configuration
         * @param propKeyIds property key token ids; more than one yields a composite schema
         * @throws KernelException if the kernel rejects the index
         */
        private void createVectorIndex(VectorIndexVersion version, IndexConfig config, int... propKeyIds)
                throws KernelException {
            try (Transaction tx = db.beginTx()) {
                KernelTransaction ktx = ((InternalTransaction) tx).kernelTransaction();
                IndexPrototype prototype = IndexPrototype.forSchema(schemaDescriptor(propKeyIds))
                        .withIndexType(IndexType.VECTOR)
                        .withIndexProvider(version.descriptor())
                        .withIndexConfig(config);
                ktx.schemaWrite().indexCreate(prototype);
                tx.commit();
            }
        }

        /** Builds this schema kind's descriptor over the given property key token ids. */
        protected abstract SchemaDescriptor schemaDescriptor(int... propKeyIds);

        /** Creates a single-property vector index through the core API and commits. */
        private void createVectorIndex(Map<IndexSetting, Object> settings, String propKey) {
            createVectorIndex(settings, List.of(propKey));
        }

        /** Creates a vector index over the given property keys through the core API and commits. */
        private void createVectorIndex(Map<IndexSetting, Object> settings, List<String> propKeys) {
            try (Transaction tx = db.beginTx()) {
                createVectorIndex(indexCreator(tx), settings, propKeys);
                tx.commit();
            }
        }

        /** Configures {@code creator} as a vector index with the given settings and properties, then creates it. */
        private void createVectorIndex(IndexCreator creator, Map<IndexSetting, Object> settings, List<String> propKeys) {
            IndexCreator configured = creator.withIndexType(IndexType.VECTOR.toPublicApi()).withIndexConfiguration(settings);
            for (String propKey : propKeys) {
                configured = configured.on(propKey);
            }
            configured.create();
        }

        /** Starts a core-API index definition on this schema kind's label or relationship type. */
        protected abstract IndexCreator indexCreator(Transaction tx);

        private static IndexConfig defaultConfig() {
            return IndexSettingUtil.defaultConfigForTest(IndexType.VECTOR.toPublicApi());
        }

        /** The default test config with a single setting overridden. */
        private static IndexConfig defaultConfigWith(IndexSetting setting, Object value) {
            return configFrom(defaultSettingsWith(setting, value));
        }

        private static IndexConfig configFrom(Map<IndexSetting, Object> settings) {
            return IndexSettingUtil.toIndexConfigFromIndexSettingObjectMap(settings);
        }

        private static Map<IndexSetting, Object> defaultSettings() {
            return IndexSettingUtil.defaultSettingsForTesting(IndexType.VECTOR.toPublicApi());
        }

        /** An unmodifiable copy of the default test settings with a single setting overridden. */
        private static Map<IndexSetting, Object> defaultSettingsWith(IndexSetting setting, Object value) {
            Map<IndexSetting, Object> settings = new HashMap<>(defaultSettings());
            settings.put(setting, value);
            return Collections.unmodifiableMap(settings);
        }

        private static void assertDoesNotThrow(ThrowableAssert.ThrowingCallable callable) {
            Assertions.assertThatCode(callable).doesNotThrowAnyException();
        }

        private static void assertUnsupportedIndex(ThrowableAssert.ThrowingCallable callable) {
            Assertions.assertThatThrownBy(callable)
                    .isInstanceOf(UnsupportedOperationException.class)
                    .hasMessageContainingAll("vector indexes with provider", "are not supported");
        }

        private static void assertIllegalDimensions(ThrowableAssert.ThrowingCallable callable) {
            Assertions.assertThatThrownBy(callable)
                    .isInstanceOf(IllegalArgumentException.class)
                    .cause()
                    .hasMessageContainingAll(
                            IndexSetting.vector_Dimensions().getSettingName(), "is expected to be positive");
        }

        private static void assertUnsupportedDimensions(ThrowableAssert.ThrowingCallable callable) {
            Assertions.assertThatThrownBy(callable)
                    .isInstanceOf(UnsupportedOperationException.class)
                    .hasMessageContainingAll(IndexSetting.vector_Dimensions().getSettingName(), "set greater than");
        }

        /** Asserts rejection whose cause message lists the similarity functions {@code version} supports. */
        private static void assertIllegalSimilarityFunction(
                VectorIndexVersion version, ThrowableAssert.ThrowingCallable callable) {
            String supportedFunctions = version.supportedSimilarityFunctions()
                    .asLazy()
                    .collect(VectorSimilarityFunction::name)
                    .toString();
            Assertions.assertThatThrownBy(callable)
                    .isInstanceOf(IllegalArgumentException.class)
                    .cause()
                    .hasMessageContainingAll(
                            "is an unsupported vector similarity function", "Supported", supportedFunctions);
        }

        private static void assertUnsupportedComposite(ThrowableAssert.ThrowingCallable callable) {
            Assertions.assertThatThrownBy(callable)
                    .isInstanceOf(UnsupportedOperationException.class)
                    .hasMessageContainingAll("Composite indexes are not supported for", IndexType.VECTOR.name(), "index type");
        }
    }
}

