/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.kernel.impl.newapi;

import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collector;
import java.util.stream.IntStream;
import org.apache.commons.collections4.map.ListOrderedMap;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.AbstractCollectionAssert;
import org.assertj.core.api.AbstractIntegerAssert;
import org.assertj.core.api.AbstractLongAssert;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assumptions;
import org.assertj.core.api.BooleanAssert;
import org.assertj.core.api.CollectionAssert;
import org.assertj.core.api.IntegerAssert;
import org.assertj.core.api.IterableAssert;
import org.assertj.core.api.SoftAssertions;
import org.assertj.core.api.junit.jupiter.InjectSoftAssertions;
import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.neo4j.common.EntityType;
import org.neo4j.exceptions.KernelException;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.internal.kernel.api.Cursor;
import org.neo4j.internal.kernel.api.PartitionedScan;
import org.neo4j.internal.kernel.api.SchemaWrite;
import org.neo4j.internal.kernel.api.exceptions.InvalidTransactionTypeKernelException;
import org.neo4j.internal.kernel.api.exceptions.schema.IndexNotApplicableKernelException;
import org.neo4j.internal.schema.IndexDescriptor;
import org.neo4j.internal.schema.IndexPrototype;
import org.neo4j.internal.schema.SchemaDescriptor;
import org.neo4j.internal.schema.SchemaDescriptors;
import org.neo4j.kernel.api.ExecutionContext;
import org.neo4j.kernel.api.Kernel;
import org.neo4j.kernel.api.KernelTransaction;
import org.neo4j.kernel.api.Statement;
import org.neo4j.kernel.api.WorkerContext;
import org.neo4j.kernel.impl.coreapi.TransactionImpl;
import org.neo4j.kernel.impl.coreapi.schema.SchemaImpl;
import org.neo4j.kernel.impl.newapi.PartitionedScanFactories;
import org.neo4j.kernel.impl.newapi.TestUtils;
import org.neo4j.storageengine.api.StorageEngine;
import org.neo4j.test.Race;
import org.neo4j.test.RandomSupport;
import org.neo4j.test.Tags;
import org.neo4j.test.extension.ImpermanentDbmsExtension;
import org.neo4j.test.extension.Inject;
import org.neo4j.test.extension.RandomExtension;

@ExtendWith(value={SoftAssertionsExtension.class, RandomExtension.class})
@ImpermanentDbmsExtension
@TestInstance(value=TestInstance.Lifecycle.PER_CLASS)
abstract class PartitionedScanTestSuite<QUERY extends Query<?>, SESSION, CURSOR extends Cursor> {
    @Inject
    private GraphDatabaseService db;
    @Inject
    protected RandomSupport random;
    @InjectSoftAssertions
    protected SoftAssertions softly;
    @Inject
    protected StorageEngine storageEngine;
    @Inject
    protected Kernel kernel;
    protected Queries<QUERY> queries;
    protected int maxNumberOfPartitions;
    protected PartitionedScanFactories.PartitionedScanFactory<QUERY, SESSION, CURSOR> factory;

    /**
     * Prepares the database content for the concrete suite; called from {@link #setup()}.
     *
     * @return the valid and invalid queries the tests in this suite iterate over
     */
    abstract Queries<QUERY> setupDatabase();

    /**
     * Binds this suite to the partitioned-scan factory exposed by the given test suite descriptor.
     *
     * @param testSuite source of the factory used by the tests in this class
     */
    PartitionedScanTestSuite(TestSuite<QUERY, SESSION, CURSOR> testSuite) {
        factory = testSuite.getFactory();
    }

    /**
     * Populates the database through {@link #setupDatabase()} and derives the largest number of
     * partitions any valid query can be split into.
     * <p>
     * An AssertJ assumption requires that at least one valid query is available.
     */
    @BeforeAll
    protected void setup() {
        queries = setupDatabase();

        final EntityIdsMatchingQuery<QUERY> validQueries = queries.valid();
        Assumptions.assumeThat(validQueries)
                .as("there are valid queries to test against")
                .isNotEmpty();

        maxNumberOfPartitions = calculateMaxNumberOfPartitions(validQueries.queries());
    }

    /**
     * Starts a new transaction on the injected database and exposes its kernel-level handle.
     *
     * @return the {@link KernelTransaction} backing the freshly begun transaction
     */
    protected final KernelTransaction beginTx() {
        final TransactionImpl transaction = (TransactionImpl) db.beginTx();
        return transaction.kernelTransaction();
    }

    /**
     * Asks the factory of the complementary entity type for a scan while handing it a session
     * obtained from this suite's own factory, and expects the mismatch to be rejected.
     */
    @Test
    final void shouldThrowWithEntityTypeComplementSeekOrScan() throws KernelException {
        try (KernelTransaction tx = beginTx()) {
            final QUERY query = getFirstValidQuery();

            softly.assertThatThrownBy(
                            () -> factory.getEntityTypeComplimentFactory()
                                    .partitionedScan(
                                            tx, factory.getSession(tx, query.indexName()), Integer.MAX_VALUE, query),
                            "should throw with mismatched entity type seek/scan method, and given index session")
                    .isInstanceOf(IndexNotApplicableKernelException.class)
                    .hasMessageContaining("can not be performed on index");
        }
    }

    /**
     * A desired partition count of zero or less must be rejected with an
     * {@link IllegalArgumentException}.
     *
     * @param desiredNumberOfPartitions the non-positive partition count under test
     */
    @ParameterizedTest
    @ValueSource(ints = {-1, 0})
    final void shouldThrowWithNonPositivePartitions(int desiredNumberOfPartitions) throws KernelException {
        try (KernelTransaction tx = beginTx()) {
            softly.assertThatThrownBy(
                            () -> factory.partitionedScan(tx, desiredNumberOfPartitions, getFirstValidQuery()),
                            "desired number of partitions must be positive")
                    .isInstanceOf(IllegalArgumentException.class)
                    .hasMessageContainingAll("Expected positive", "value");
        }
    }

    /**
     * Once the transaction carries uncommitted changes, constructing a partitioned scan must fail
     * with an {@link IllegalStateException}.
     */
    @Test
    final void shouldThrowOnConstructionWithTransactionState() throws KernelException {
        try (KernelTransaction tx = beginTx()) {
            // given: a transaction that has written something
            createState(tx);
            final boolean hasChanges = tx.dataRead().transactionStateHasChanges();
            softly.assertThat(hasChanges).as("transaction state").isTrue();

            // then: building the scan is refused
            softly.assertThatThrownBy(
                            () -> factory.partitionedScan(tx, Integer.MAX_VALUE, getFirstValidQuery()),
                            "should throw on construction of scan, with transaction state")
                    .isInstanceOf(IllegalStateException.class)
                    .hasMessage("Transaction contains changes; PartitionScan is only valid in Read-Only transactions.");
        }
    }

    /**
     * Every query listed as invalid must be rejected with an
     * {@link IndexNotApplicableKernelException}; an assumption requires that such queries exist.
     */
    @Test
    final void shouldThrowWithInvalidQuery() throws KernelException {
        final Set<QUERY> invalidQueries = queries.invalid();
        Assumptions.assumeThat(invalidQueries)
                .as("there are invalid queries to test against")
                .isNotEmpty();

        try (KernelTransaction tx = beginTx()) {
            for (final QUERY query : invalidQueries) {
                softly.assertThatThrownBy(
                                () -> factory.partitionedScan(tx, Integer.MAX_VALUE, query),
                                "should throw with an invalid query")
                        .isInstanceOf(IndexNotApplicableKernelException.class)
                        .hasMessageContaining("This index does not support partitioned scan for this query");
            }
        }
    }

    /**
     * @return the first valid query in iteration order
     */
    private QUERY getFirstValidQuery() {
        final Map.Entry<QUERY, Set<Long>> firstEntry = queries.valid().iterator().next();
        return firstEntry.getKey();
    }

    /**
     * Looks up the name of the single token index for the given entity type.
     * <p>
     * The schema lookup happens inside a transaction that is always closed; any exception raised
     * by the lookup or by closing the transaction is wrapped in an {@link AssertionError}. The
     * assumptions ("an index exists", "only one exists") are evaluated after the transaction
     * block, so a failed assumption propagates as-is instead of being wrapped into an
     * {@link AssertionError} by the catch clause.
     *
     * @param entityType entity type whose token index is wanted
     * @return the name of the token index
     * @throws AssertionError if the schema lookup or closing the transaction fails
     */
    protected String getTokenIndexName(EntityType entityType) {
        String firstName = null;
        int found = 0;
        try (KernelTransaction tx = beginTx()) {
            final SchemaDescriptor schema = SchemaDescriptors.forAnyEntityTokens(entityType);
            final Iterator<IndexDescriptor> indexes = tx.schemaRead().index(schema);
            // Inspecting at most two entries is enough to tell "none", "exactly one" and "several" apart.
            while (found < 2 && indexes.hasNext()) {
                final IndexDescriptor index = indexes.next();
                if (found == 0) {
                    firstName = index.getName();
                }
                found++;
            }
        } catch (Exception e) {
            throw new AssertionError(String.format("failed to get %s based token index", entityType), e);
        }

        Assumptions.assumeThat(found > 0)
                .as("%s based token index exists", entityType)
                .isTrue();
        Assumptions.assumeThat(found > 1)
                .as("only one %s based token index exists", entityType)
                .isFalse();
        return firstName;
    }

    /**
     * Creates the given indexes in one transaction and then waits, in a second transaction, for
     * the indexes to come online.
     * <p>
     * Both transactions are managed with try-with-resources, so a failure while closing is added
     * as a suppressed exception rather than replacing the exception that caused the failure.
     *
     * @param indexPrototypes prototypes of the indexes to create
     * @throws AssertionError if index creation or waiting for the indexes fails
     */
    protected void createIndexes(Iterable<IndexPrototype> indexPrototypes) {
        try (KernelTransaction tx = beginTx()) {
            final SchemaWrite schemaWrite = tx.schemaWrite();
            for (final IndexPrototype indexPrototype : indexPrototypes) {
                schemaWrite.indexCreate(indexPrototype);
            }
            tx.commit();
        } catch (Exception e) {
            throw new AssertionError("failed to create indexes", e);
        }

        try (KernelTransaction tx = beginTx()) {
            new SchemaImpl(tx).awaitIndexesOnline(1L, TimeUnit.HOURS);
        } catch (Exception e) {
            throw new AssertionError("failed waiting for indexes to come online", e);
        }
    }

    /**
     * Determines the largest number of partitions any of the given queries yields when the scan
     * is requested with {@link Integer#MAX_VALUE} desired partitions.
     * <p>
     * The transaction is handled by try-with-resources, so failures while closing it are wrapped
     * like every other failure.
     *
     * @param queries queries to build partitioned scans for
     * @return the maximum partition count seen, or {@code 0} when {@code queries} is empty
     * @throws AssertionError if building a scan or closing the transaction fails
     */
    protected int calculateMaxNumberOfPartitions(Iterable<QUERY> queries) {
        try (KernelTransaction tx = beginTx()) {
            int maxNumberOfPartitions = 0;
            for (final QUERY query : queries) {
                final int partitions = factory.partitionedScan(tx, Integer.MAX_VALUE, query)
                        .getNumberOfPartitions();
                maxNumberOfPartitions = Math.max(maxNumberOfPartitions, partitions);
            }
            return maxNumberOfPartitions;
        } catch (Exception e) {
            throw new AssertionError("failed to calculated max number of partitions", e);
        }
    }

    /**
     * Gives the transaction some state by creating a node through its write API.
     *
     * @param tx transaction to write into
     * @throws InvalidTransactionTypeKernelException propagated from acquiring the write API
     */
    private static void createState(KernelTransaction tx) throws InvalidTransactionTypeKernelException {
        final var write = tx.dataWrite();
        write.nodeCreate();
    }

    /**
     * Obtains a single tag id from the supplier inside a transaction and commits it.
     *
     * @param tag supplier used to obtain the id
     * @return the id handed back by the supplier
     * @throws AssertionError if a {@link KernelException} is raised along the way
     */
    protected final <TAG> int createTag(Tags.Suppliers.Supplier<TAG> tag) {
        try (KernelTransaction tx = beginTx()) {
            final int id = tag.getId(tx);
            tx.commit();
            return id;
        } catch (KernelException e) {
            final String message = String.format("failed to create %ss in database", tag.name());
            throw new AssertionError(message, e);
        }
    }

    /**
     * Obtains {@code numberOfTags} tag ids from the supplier inside a transaction and commits it.
     *
     * @param numberOfTags how many ids to request
     * @param tag supplier used to obtain the ids
     * @return the ids handed back by the supplier
     * @throws AssertionError if a {@link KernelException} is raised along the way
     */
    protected final <TAG> List<Integer> createTags(int numberOfTags, Tags.Suppliers.Supplier<TAG> tag) {
        try (KernelTransaction tx = beginTx()) {
            final var ids = tag.getIds(tx, numberOfTags);
            tx.commit();
            return ids;
        } catch (KernelException e) {
            final String message = String.format("failed to create %ss in database", tag.name());
            throw new AssertionError(message, e);
        }
    }

    /**
     * Descriptor of a concrete suite; its only responsibility visible here is to supply the
     * partitioned-scan factory that the enclosing test class stores and exercises.
     */
    static interface TestSuite<QUERY extends Query<?>, SESSION, CURSOR extends Cursor> {
        /**
         * @return the factory used to create partitioned scans, sessions and cursors for the tests
         */
        public PartitionedScanFactories.PartitionedScanFactory<QUERY, SESSION, CURSOR> getFactory();
    }

    /**
     * The queries a suite runs against: the valid ones (with their expected matches) and the
     * invalid ones.
     *
     * @param valid   valid queries together with the entity ids expected to match each of them
     * @param invalid queries expected to be rejected; exposed through an unmodifiable view
     */
    protected record Queries<QUERY extends Query<?>>(EntityIdsMatchingQuery<QUERY> valid, Set<QUERY> invalid) {
        /** Wraps the invalid queries in an unmodifiable view before they are stored. */
        public Queries {
            invalid = Collections.unmodifiableSet(invalid);
        }

        /**
         * Convenience constructor for suites without invalid queries.
         *
         * @param valid valid queries together with their expected matches
         */
        public Queries(EntityIdsMatchingQuery<QUERY> valid) {
            this(valid, Set.of());
        }
    }

    /**
     * Association of each query with the set of entity ids expected to match it, backed by a
     * {@link ListOrderedMap}. Iteration and {@link #queries()} expose unmodifiable views.
     */
    protected static final class EntityIdsMatchingQuery<QUERY extends Query<?>>
            implements Iterable<Map.Entry<QUERY, Set<Long>>> {
        // Parameterized (instead of raw) so that no unchecked conversions are needed.
        private final Map<QUERY, Set<Long>> matches = new ListOrderedMap<>();

        protected EntityIdsMatchingQuery() {
        }

        /**
         * @return a collector that registers every streamed query (with an initially empty id set)
         *         and merges partial results with {@link #addAll(EntityIdsMatchingQuery)}
         */
        static <QUERY extends Query<?>> Collector<QUERY, EntityIdsMatchingQuery<QUERY>, EntityIdsMatchingQuery<QUERY>> collector() {
            return Collector.of(
                    EntityIdsMatchingQuery::new, EntityIdsMatchingQuery::getOrCreate, EntityIdsMatchingQuery::addAll);
        }

        /**
         * @param query query to look up
         * @return the id set registered for {@code query}, creating an empty one when absent
         */
        Set<Long> getOrCreate(QUERY query) {
            return matches.computeIfAbsent(query, q -> new HashSet<>());
        }

        /**
         * Registers {@code entityIds} for {@code query}, replacing any previous registration.
         *
         * @return the previously registered id set, or {@code null} if there was none
         */
        Set<Long> addOrReplace(QUERY query, Set<Long> entityIds) {
            return matches.put(query, entityIds);
        }

        /**
         * Copies every registration of {@code other} into this instance.
         *
         * @return this instance
         */
        EntityIdsMatchingQuery<QUERY> addAll(EntityIdsMatchingQuery<QUERY> other) {
            matches.putAll(other.matches);
            return this;
        }

        /**
         * @return an unmodifiable view of the registered queries
         */
        Set<QUERY> queries() {
            return Collections.unmodifiableMap(matches).keySet();
        }

        /**
         * @return an iterator over an unmodifiable view of the query-to-ids entries
         */
        @Override
        public Iterator<Map.Entry<QUERY, Set<Long>>> iterator() {
            return Collections.unmodifiableMap(matches).entrySet().iterator();
        }
    }

    /**
     * A query used by the suite: the name of the index it targets plus the wrapped query value.
     *
     * @param <QUERY> type of the wrapped query value
     */
    protected static interface Query<QUERY> {
        /**
         * @return the name of the index this query is associated with
         */
        public String indexName();

        /**
         * @return the wrapped query value
         */
        public QUERY get();
    }

    /**
     * A half-open interval {@code [min, max)} of {@code long} values.
     *
     * @param min inclusive lower bound
     * @param max exclusive upper bound
     */
    protected record Range(long min, long max) {
        /**
         * @return whether {@code value} lies in {@code [min, max)}
         */
        boolean contains(long value) {
            return value >= min && value < max;
        }

        /**
         * Computes the boundary of the {@code n}-th of {@code q} quantiles of this range.
         * An assumption requires {@code n} to lie between 0 and {@code q}.
         *
         * @param n index of the quantile boundary
         * @param q total number of quantiles
         * @return {@code min + n * (max - min) / q}, evaluated with {@code long} arithmetic
         */
        long quantile(long n, long q) {
            Assumptions.assumeThat(n)
                    .as("given numbered quantile, is a valid quantile")
                    .isBetween(0L, q);
            final long span = max - min;
            return min + n * span / q;
        }

        /**
         * @return a value drawn from {@code random} with origin {@code min} and bound {@code max}
         */
        long random(Random random) {
            return random.nextLong(min, max);
        }

        /**
         * Draws a random value between the {@code n}-th and {@code m}-th of {@code q} quantile
         * boundaries, whichever order they come in.
         */
        long randomBetweenQuantiles(Random random, long n, long m, long q) {
            final long first = quantile(n, q);
            final long second = quantile(m, q);
            return createSane(first, second).random(random);
        }

        /**
         * @return a range spanning from the smaller to the larger of the two arguments
         */
        static Range createSane(long x, long y) {
            return new Range(Math.min(x, y), Math.max(x, y));
        }

        /**
         * @return the smallest range covering both arguments; a {@code null} argument is ignored,
         *         so the result is {@code null} only when both are {@code null}
         */
        static Range union(Range lhs, Range rhs) {
            if (lhs == null) {
                return rhs;
            }
            if (rhs == null) {
                return lhs;
            }
            final long lowest = Math.min(lhs.min, rhs.min);
            final long highest = Math.max(lhs.max, rhs.max);
            return new Range(lowest, highest);
        }

        /**
         * @return {@code true} when both ranges have {@code min <= max} and {@code lhs} ends at or
         *         before the start of {@code rhs}
         */
        static boolean strictlyLessThan(Range lhs, Range rhs) {
            if (lhs.min > lhs.max) {
                return false;
            }
            if (rhs.min > rhs.max) {
                return false;
            }
            return lhs.max <= rhs.min;
        }
    }

    static abstract class WithData<QUERY extends Query<?>, SESSION, CURSOR extends Cursor>
    extends PartitionedScanTestSuite<QUERY, SESSION, CURSOR> {
        /**
         * Passes the suite descriptor on to the enclosing suite's constructor.
         *
         * @param testSuite descriptor supplying the factory under test
         */
        WithData(TestSuite<QUERY, SESSION, CURSOR> testSuite) {
            super(testSuite);
        }

        /**
         * Runs the base setup and then assumes that more than one partition is achievable.
         */
        @Override
        @BeforeAll
        protected void setup() {
            super.setup();
            Assumptions.assumeThat(maxNumberOfPartitions)
                    .as("max number of partitions is enough to test partitions")
                    .isGreaterThan(1);
        }

        /**
         * Reserves exactly one partition per valid query and checks that everything read from it
         * is free of duplicates and is a subset of the expected matches.
         *
         * @param api the partition-reservation API variant to exercise
         */
        @ParameterizedTest
        @EnumSource(TestUtils.PartitionedScanAPI.class)
        final void shouldScanSubsetOfEntriesWithSinglePartition(TestUtils.PartitionedScanAPI api) throws KernelException {
            try (KernelTransaction tx = beginTx();
                    CURSOR entities = factory.getCursor(tx.cursors()).with(tx.cursorContext());
                    Statement statement = tx.acquireStatement();
                    ExecutionContext executionContext = tx.createExecutionContext()) {
                for (final Map.Entry<QUERY, Set<Long>> entry : queries.valid()) {
                    final QUERY query = entry.getKey();
                    final Set<Long> expectedMatches = entry.getValue();

                    final PartitionedScan<CURSOR> scan = factory.partitionedScan(tx, maxNumberOfPartitions, query);
                    softly.assertThat(scan.getNumberOfPartitions())
                            .as("number of partitions")
                            .isGreaterThan(0)
                            .isLessThanOrEqualTo(maxNumberOfPartitions);

                    // Only the first partition is reserved on purpose: a subset is all we expect.
                    final Set<Long> found = new HashSet<>();
                    api.reservePartition(scan, entities, tx, executionContext);
                    while (entities.next()) {
                        softly.assertThat(found.add(factory.getEntityReference(entities)))
                                .as("no duplicate")
                                .isTrue();
                    }

                    // The soft assertion is only built when the plain check has already failed.
                    if (!expectedMatches.containsAll(found)) {
                        softly.assertThat(expectedMatches)
                                .as("subset of all matches for %s", query)
                                .containsAll(found);
                    }
                }
                executionContext.complete();
            }
        }

        /**
         * Runs the single-threaded check while asking for {@link Integer#MAX_VALUE} partitions.
         *
         * @param api the partition-reservation API variant to exercise
         */
        @ParameterizedTest
        @EnumSource(TestUtils.PartitionedScanAPI.class)
        final void shouldCreateNoMorePartitionsThanPossible(TestUtils.PartitionedScanAPI api) throws KernelException {
            singleThreadedCheck(api, Integer.MAX_VALUE);
        }

        /**
         * Runs the single-threaded check with the {@code NEW} API for the given partition count.
         *
         * @param desiredNumberOfPartitions partition count supplied by {@code rangeFromOneToMaxPartitions}
         */
        @ParameterizedTest(name = "desiredNumberOfPartitions={0}")
        @MethodSource("rangeFromOneToMaxPartitions")
        final void shouldScanAllEntriesWithGivenNumberOfPartitionsSingleThreaded(int desiredNumberOfPartitions) throws KernelException {
            final TestUtils.PartitionedScanAPI api = TestUtils.PartitionedScanAPI.NEW;
            singleThreadedCheck(api, desiredNumberOfPartitions);
        }

        /**
         * Runs the multi-threaded check using as many threads as desired partitions.
         *
         * @param desiredNumberOfPartitions partition (and thread) count supplied by
         *                                  {@code rangeFromOneToMaxPartitions}
         */
        @ParameterizedTest(name = "desiredNumberOfPartitions={0}")
        @MethodSource("rangeFromOneToMaxPartitions")
        final void shouldScanMultiplePartitionsInParallelWithSameNumberOfThreads(int desiredNumberOfPartitions) throws KernelException {
            final int numberOfThreads = desiredNumberOfPartitions;
            multiThreadedCheck(desiredNumberOfPartitions, numberOfThreads);
        }

        /**
         * Runs the multi-threaded check asking for the maximum number of partitions while using
         * the given number of threads.
         *
         * @param desiredNumberOfTheads thread count supplied by {@code rangeFromOneToMaxPartitions}
         */
        @ParameterizedTest(name = "desiredNumberOfThreads={0}")
        @MethodSource("rangeFromOneToMaxPartitions")
        final void shouldScanMultiplePartitionsInParallelWithFewerThreads(int desiredNumberOfTheads) throws KernelException {
            final int desiredNumberOfPartitions = maxNumberOfPartitions;
            multiThreadedCheck(desiredNumberOfPartitions, desiredNumberOfTheads);
        }

        /**
         * For every valid query, reserves partitions one after another on a single cursor and
         * checks that the ids read are free of duplicates and equal the expected matches.
         *
         * @param api the partition-reservation API variant to exercise
         * @param desiredNumberOfPartitions partition count requested from the factory
         */
        private void singleThreadedCheck(TestUtils.PartitionedScanAPI api, int desiredNumberOfPartitions) throws KernelException {
            try (KernelTransaction tx = beginTx();
                    CURSOR entities = factory.getCursor(tx.cursors()).with(tx.cursorContext());
                    Statement statement = tx.acquireStatement();
                    ExecutionContext executionContext = tx.createExecutionContext()) {
                for (final Map.Entry<QUERY, Set<Long>> entry : queries.valid()) {
                    final QUERY query = entry.getKey();
                    final Set<Long> expectedMatches = entry.getValue();

                    final PartitionedScan<CURSOR> scan = factory.partitionedScan(tx, desiredNumberOfPartitions, query);
                    softly.assertThat(scan.getNumberOfPartitions())
                            .as("number of partitions")
                            .isGreaterThan(0)
                            .isLessThanOrEqualTo(desiredNumberOfPartitions)
                            .isLessThanOrEqualTo(maxNumberOfPartitions);

                    final Set<Long> found = new HashSet<>();
                    while (api.reservePartition(scan, entities, tx, executionContext)) {
                        while (entities.next()) {
                            softly.assertThat(found.add(factory.getEntityReference(entities)))
                                    .as("no duplicate")
                                    .isTrue();
                        }
                    }

                    // The soft assertion is only built when the plain check has already failed.
                    if (!expectedMatches.equals(found)) {
                        softly.assertThat(found)
                                .as("only the expected data found matching %s", query)
                                .containsExactlyInAnyOrderElementsOf(expectedMatches);
                    }
                }
                executionContext.complete();
            }
        }

        /**
         * For every valid query, lets {@code numberOfThreads} contestants reserve partitions of a
         * shared scan concurrently and checks that the ids collected across all of them are free
         * of duplicates and equal the expected matches.
         * <p>
         * The worker contexts are closed in a {@code finally} block so they are released even when
         * the race fails.
         *
         * @param desiredNumberOfPartitions partition count requested from the factory
         * @param numberOfThreads number of worker contexts (and contestants) to create
         */
        private void multiThreadedCheck(int desiredNumberOfPartitions, int numberOfThreads) throws KernelException {
            try (KernelTransaction tx = beginTx()) {
                for (final Map.Entry<QUERY, Set<Long>> entry : queries.valid()) {
                    final QUERY query = entry.getKey();
                    final Set<Long> expectedMatches = entry.getValue();

                    final PartitionedScan<CURSOR> scan = factory.partitionedScan(tx, desiredNumberOfPartitions, query);
                    softly.assertThat(scan.getNumberOfPartitions())
                            .as("number of partitions")
                            .isGreaterThan(0)
                            .isLessThanOrEqualTo(desiredNumberOfPartitions)
                            .isLessThanOrEqualTo(maxNumberOfPartitions);

                    // Shared between contestants, hence the synchronized wrapper.
                    final Set<Long> allFound = Collections.synchronizedSet(new HashSet<>());
                    final List<WorkerContext<CURSOR>> workerContexts =
                            TestUtils.createContexts(tx, factory.getCursor(kernel.cursors())::with, numberOfThreads);
                    try {
                        final Race race = new Race();
                        for (final WorkerContext<CURSOR> workerContext : workerContexts) {
                            race.addContestant(() -> {
                                final ExecutionContext executionContext = workerContext.getContext();
                                try (CURSOR entities = workerContext.getCursor()) {
                                    // Collect per thread first, then merge into the shared set.
                                    final Set<Long> found = new HashSet<>();
                                    while (scan.reservePartition(entities, executionContext)) {
                                        while (entities.next()) {
                                            softly.assertThat(found.add(factory.getEntityReference(entities)))
                                                    .as("no duplicate")
                                                    .isTrue();
                                        }
                                    }
                                    found.forEach(id -> softly.assertThat(allFound.add(id))
                                            .as("no duplicates")
                                            .isTrue());
                                } finally {
                                    executionContext.complete();
                                }
                            });
                        }
                        race.goUnchecked();
                    } finally {
                        workerContexts.forEach(WorkerContext::close);
                    }

                    // The soft assertion is only built when the plain check has already failed.
                    if (!expectedMatches.equals(allFound)) {
                        softly.assertThat(allFound)
                                .as("only the expected data found matching %s", query)
                                .containsExactlyInAnyOrderElementsOf(expectedMatches);
                    }
                }
            }
        }

        /**
         * Argument source for the parameterized tests.
         *
         * @return the integers from 1 up to and including {@code maxNumberOfPartitions}
         */
        protected IntStream rangeFromOneToMaxPartitions() {
            final int upperInclusive = maxNumberOfPartitions;
            return IntStream.rangeClosed(1, upperInclusive);
        }
    }

    /**
     * Variant of the suite whose single test expects that no partition of any valid query
     * yields data.
     */
    static abstract class WithoutData<QUERY extends Query<?>, SESSION, CURSOR extends Cursor>
            extends PartitionedScanTestSuite<QUERY, SESSION, CURSOR> {
        /**
         * @param testSuite descriptor supplying the factory under test
         */
        WithoutData(TestSuite<QUERY, SESSION, CURSOR> testSuite) {
            super(testSuite);
        }

        /**
         * Reserves every partition of every valid query and checks that the cursor reports no
         * entry after each reservation.
         *
         * @param api the partition-reservation API variant to exercise
         */
        @ParameterizedTest
        @EnumSource(TestUtils.PartitionedScanAPI.class)
        final void shouldHandleEmptyDatabase(TestUtils.PartitionedScanAPI api) throws KernelException {
            try (KernelTransaction tx = beginTx();
                    CURSOR entities = factory.getCursor(tx.cursors()).with(tx.cursorContext());
                    Statement statement = tx.acquireStatement();
                    ExecutionContext executionContext = tx.createExecutionContext()) {
                for (final Map.Entry<QUERY, Set<Long>> entry : queries.valid()) {
                    final QUERY query = entry.getKey();
                    final PartitionedScan<CURSOR> scan = factory.partitionedScan(tx, Integer.MAX_VALUE, query);
                    while (api.reservePartition(scan, entities, tx, executionContext)) {
                        softly.assertThat(entities.next())
                                .as("no data should be found for %s", query)
                                .isFalse();
                    }
                }
                executionContext.complete();
            }
        }
    }
}

