package cn.boboweike.carrot.storage.nosql.mongo;

import cn.boboweike.carrot.lock.LockProvider;
import cn.boboweike.carrot.lock.nosql.MongoLockProvider;
import cn.boboweike.carrot.scheduling.partition.Partitioner;
import cn.boboweike.carrot.scheduling.partition.RandomPartitioner;
import cn.boboweike.carrot.storage.*;
import cn.boboweike.carrot.storage.StorageProviderUtils.DatabaseOptions;
import cn.boboweike.carrot.storage.StorageProviderUtils.Tasks;
import cn.boboweike.carrot.storage.StorageProviderUtils.RecurringTasks;
import cn.boboweike.carrot.storage.StorageProviderUtils.BackgroundTaskServers;
import cn.boboweike.carrot.storage.StorageProviderUtils.Metadata;
import cn.boboweike.carrot.storage.TaskStats;
import cn.boboweike.carrot.storage.nosql.mongo.mapper.BackgroundTaskServerStatusDocumentMapper;
import cn.boboweike.carrot.storage.nosql.mongo.mapper.MetadataDocumentMapper;
import cn.boboweike.carrot.storage.nosql.mongo.mapper.MongoDBPageRequestMapper;
import cn.boboweike.carrot.storage.nosql.mongo.mapper.TaskDocumentMapper;
import cn.boboweike.carrot.tasks.*;
import cn.boboweike.carrot.tasks.mappers.TaskMapper;
import cn.boboweike.carrot.tasks.states.StateName;
import cn.boboweike.carrot.utils.resilience.RateLimiter;
import com.mongodb.*;
import com.mongodb.bulk.BulkWriteResult;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.*;
import com.mongodb.client.result.DeleteResult;
import com.mongodb.client.result.InsertOneResult;
import com.mongodb.client.result.UpdateResult;
import org.bson.Document;
import org.bson.UuidRepresentation;
import org.bson.codecs.UuidCodec;
import org.bson.codecs.configuration.CodecRegistries;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.conversions.Bson;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static cn.boboweike.carrot.CarrotException.shouldNotHappenException;
import static cn.boboweike.carrot.storage.CarrotMetadata.toId;
import static cn.boboweike.carrot.storage.StorageProviderUtils.DatabaseOptions.CREATE;
import static cn.boboweike.carrot.storage.StorageProviderUtils.elementPrefixer;
import static cn.boboweike.carrot.storage.StorageProviderUtils.elementPrefixerWithPartition;
import static cn.boboweike.carrot.storage.StorageProviderUtils.*;
import static cn.boboweike.carrot.tasks.states.StateName.*;
import static cn.boboweike.carrot.utils.TaskUtils.getTaskSignature;
import static cn.boboweike.carrot.utils.reflection.ReflectionUtils.findMethod;
import static cn.boboweike.carrot.utils.resilience.RateLimiter.Builder.rateLimit;
import static cn.boboweike.carrot.utils.resilience.RateLimiter.SECOND;
import static com.mongodb.client.model.Aggregates.*;
import static com.mongodb.client.model.Filters.*;
import static com.mongodb.client.model.Projections.include;
import static com.mongodb.client.model.Sorts.ascending;
import static java.util.Arrays.asList;
import static java.util.Arrays.stream;
import static java.util.Optional.ofNullable;

/**
 * MongoDB-backed {@link PartitionedStorageProvider}.
 * <p>
 * Tasks and recurring tasks are spread over {@code totalNumOfPartitions} collections each (one task
 * collection and one recurring-task collection per partition, numbered {@code 0 .. totalNumOfPartitions - 1}).
 * Background task servers, metadata and partition locks are kept in single, non-partitioned collections.
 * <p>
 * Every method that takes a {@code partition} argument throws {@link IllegalArgumentException} when no
 * collection exists for that partition.
 */
public class MongoDBPartitionedStorageProvider extends AbstractPartitionedStorageProvider implements PartitionedStorageProvider {

    public static final String DEFAULT_DB_NAME = "carrot";

    private static final MongoDBPageRequestMapper pageRequestMapper = new MongoDBPageRequestMapper();

    private final String databaseName;
    private final MongoClient mongoClient;
    private final MongoDatabase carrotDatabase;
    // partition number -> task collection for that partition
    private final Map<Integer, MongoCollection<org.bson.Document>> taskCollectionMap = new HashMap<>();
    // partition number -> recurring task collection for that partition
    private final Map<Integer, MongoCollection<org.bson.Document>> recurringTaskCollectionMap = new HashMap<>();
    private final MongoCollection<Document> metadataCollection;
    private final MongoCollection<Document> backgroundTaskServerCollection;
    private final LockProvider lockProvider;
    private final int totalNumOfPartitions;
    private final String collectionPrefix;

    private Partitioner partitioner;

    static final String ERR_MSG_INVALID_PARTITION = "invalid partition {%s}, DB operation will be ignored";

    private TaskDocumentMapper taskDocumentMapper;
    private BackgroundTaskServerStatusDocumentMapper backgroundTaskServerStatusDocumentMapper;
    private MetadataDocumentMapper metadataDocumentMapper;

    /**
     * Creates a provider for the given connection string using the default number of partitions.
     *
     * @param connectionString MongoDB connection string
     */
    public MongoDBPartitionedStorageProvider(String connectionString) {
        this(connectionString, DEFAULT_NUM_OF_PARTITIONS);
    }

    /**
     * Creates a provider with its own {@link MongoClient}, configured with the STANDARD UUID representation.
     *
     * @param connectionString     MongoDB connection string
     * @param totalNumOfPartitions number of task / recurring task partitions (must be at least 1)
     */
    public MongoDBPartitionedStorageProvider(String connectionString, int totalNumOfPartitions) {
        this(MongoClients.create(
                MongoClientSettings.builder()
                        .applyConnectionString(new ConnectionString(connectionString))
                        .codecRegistry(CodecRegistries.fromRegistries(
                                CodecRegistries.fromCodecs(new UuidCodec(UuidRepresentation.STANDARD)),
                                MongoClientSettings.getDefaultCodecRegistry()
                        )).build()),
                totalNumOfPartitions);
    }

    public MongoDBPartitionedStorageProvider(MongoClient mongoClient, int totalNumOfPartitions) {
        this(
                mongoClient,
                rateLimit().at1Request().per(SECOND),
                totalNumOfPartitions
        );
    }

    public MongoDBPartitionedStorageProvider(MongoClient mongoClient, String dbName, int totalNumOfPartitions) {
        this(
                mongoClient,
                dbName,
                null,
                CREATE,
                rateLimit().at1Request().per(SECOND),
                totalNumOfPartitions
        );
    }

    public MongoDBPartitionedStorageProvider(
            MongoClient mongoClient,
            String dbName,
            DatabaseOptions databaseOptions,
            int totalNumOfPartitions) {
        this(
                mongoClient,
                dbName,
                null,
                databaseOptions,
                rateLimit().at1Request().per(SECOND),
                totalNumOfPartitions
        );
    }

    public MongoDBPartitionedStorageProvider(
            MongoClient mongoClient,
            String dbName,
            String collectionPrefix,
            int totalNumOfPartitions
    ) {
        this(mongoClient,
                dbName,
                collectionPrefix,
                CREATE,
                rateLimit().at1Request().per(SECOND),
                totalNumOfPartitions);
    }

    public MongoDBPartitionedStorageProvider(
            MongoClient mongoClient,
            String dbName,
            String collectionPrefix,
            DatabaseOptions databaseOptions,
            int totalNumOfPartitions) {
        this(mongoClient,
                dbName,
                collectionPrefix,
                databaseOptions,
                rateLimit().at1Request().per(SECOND),
                totalNumOfPartitions);
    }

    public MongoDBPartitionedStorageProvider(
            MongoClient mongoClient,
            RateLimiter changeListenerNotificationRateLimit,
            int totalNumOfPartitions
    ) {
        this(mongoClient,
                null,
                null,
                CREATE,
                changeListenerNotificationRateLimit,
                totalNumOfPartitions);
    }

    public MongoDBPartitionedStorageProvider(
            MongoClient mongoClient,
            DatabaseOptions databaseOptions,
            RateLimiter changeListenerNotificationRateLimit,
            int totalNumOfPartitions) {
        this(mongoClient,
                null,
                null,
                databaseOptions,
                changeListenerNotificationRateLimit,
                totalNumOfPartitions);
    }

    /**
     * Main constructor; all other constructors delegate here.
     *
     * @param mongoClient                         client to use; must have an explicit UUID representation configured
     * @param dbName                              database name, or {@code null} for {@link #DEFAULT_DB_NAME}
     * @param collectionPrefix                    optional prefix for all collection names (may be {@code null})
     * @param databaseOptions                     {@code CREATE} runs migrations, anything else only validates collections
     * @param changeListenerNotificationRateLimit rate limiter for change listener notifications
     * @param totalNumOfPartitions                number of partitions (must be at least 1)
     * @throws StorageException         if the client's UUID representation is UNSPECIFIED
     * @throws IllegalArgumentException if {@code totalNumOfPartitions < 1}
     */
    public MongoDBPartitionedStorageProvider(
            MongoClient mongoClient,
            String dbName,
            String collectionPrefix,
            DatabaseOptions databaseOptions,
            RateLimiter changeListenerNotificationRateLimit,
            int totalNumOfPartitions) {
        super(changeListenerNotificationRateLimit);
        validateMongoClient(mongoClient);

        this.databaseName = ofNullable(dbName).orElse(DEFAULT_DB_NAME);
        this.collectionPrefix = collectionPrefix;
        this.mongoClient = mongoClient;
        if (totalNumOfPartitions < 1) {
            throw new IllegalArgumentException("The totalNumOfPartitions can not be smaller than 1!");
        }
        this.totalNumOfPartitions = totalNumOfPartitions;
        this.partitioner = new RandomPartitioner(this.totalNumOfPartitions);

        carrotDatabase = mongoClient.getDatabase(databaseName);

        // migrations / validation need totalNumOfPartitions, so this must run after it is assigned
        setUpStorageProvider(databaseOptions);

        for (int partition = 0; partition < this.totalNumOfPartitions; partition++) {
            MongoCollection<Document> taskCollection =
                    carrotDatabase.getCollection(elementPrefixerWithPartition(collectionPrefix, Tasks.NAME, partition), Document.class);
            taskCollectionMap.put(partition, taskCollection);
            MongoCollection<Document> recurringTaskCollection =
                    carrotDatabase.getCollection(elementPrefixerWithPartition(collectionPrefix, RecurringTasks.NAME, partition), Document.class);
            recurringTaskCollectionMap.put(partition, recurringTaskCollection);
        }

        backgroundTaskServerCollection = carrotDatabase.getCollection(elementPrefixer(collectionPrefix, BackgroundTaskServers.NAME), Document.class);
        metadataCollection = carrotDatabase.getCollection(elementPrefixer(collectionPrefix, Metadata.NAME), Document.class);
        MongoCollection<Document> shedLockCollection = carrotDatabase.getCollection(elementPrefixer(collectionPrefix, ShedLock.NAME), Document.class);
        lockProvider = new MongoLockProvider(shedLockCollection);
    }

    @Override
    public int getTotalNumOfPartitions() {
        return totalNumOfPartitions;
    }

    /** Acquires the lock named {@code PARTITION_PREFIX + partition} for the given duration. */
    @Override
    public boolean lockByPartition(Integer partition, int durationInSeconds, String lockedBy) {
        return this.lockProvider.lock(PARTITION_PREFIX + partition, durationInSeconds, lockedBy);
    }

    /** Extends the lock named {@code PARTITION_PREFIX + partition} held by {@code lockedBy}. */
    @Override
    public boolean extendLockByPartition(Integer partition, int durationInSeconds, String lockedBy) {
        return this.lockProvider.extend(PARTITION_PREFIX + partition, durationInSeconds, lockedBy);
    }

    /** Releases the lock named {@code PARTITION_PREFIX + partition}. */
    @Override
    public boolean unlockByPartition(Integer partition) {
        return this.lockProvider.unlock(PARTITION_PREFIX + partition);
    }

    /** Installs the task mapper and (re)creates the document mappers that depend on it. */
    @Override
    public void setTaskMapper(TaskMapper taskMapper) {
        this.taskDocumentMapper = new TaskDocumentMapper(taskMapper);
        this.backgroundTaskServerStatusDocumentMapper = new BackgroundTaskServerStatusDocumentMapper();
        this.metadataDocumentMapper = new MetadataDocumentMapper();
    }

    /**
     * Replaces the partitioner used to pick a partition for tasks without a partition hint.
     *
     * @throws NullPointerException if {@code partitioner} is null
     */
    @Override
    public void setPartitioner(Partitioner partitioner) {
        this.partitioner = Objects.requireNonNull(partitioner, "partitioner");
    }

    /** Runs migrations when {@code databaseOptions} is {@code CREATE}; otherwise only validates the collections. */
    @Override
    public void setUpStorageProvider(StorageProviderUtils.DatabaseOptions databaseOptions) {
        if (CREATE == databaseOptions) {
            runMigrations();
        } else {
            validateTables();
        }
    }

    @Override
    public void announceBackgroundTaskServer(BackgroundTaskServerStatus serverStatus) {
        InsertOneResult result = this.backgroundTaskServerCollection.insertOne(backgroundTaskServerStatusDocumentMapper.toInsertDocument(serverStatus));
        if (!result.wasAcknowledged()) {
            throw new StorageException("Unable to announce BackgroundTaskServer");
        }
    }

    /**
     * Writes the server's heartbeat and returns whether the server is still flagged as running.
     *
     * @throws ServerTimedOutException if no server document with the given id exists
     */
    @Override
    public boolean signalBackgroundTaskServerAlive(BackgroundTaskServerStatus serverStatus) {
        final Bson byServerId = eq(toMongoId(BackgroundTaskServers.FIELD_ID), serverStatus.getId());
        final UpdateResult updateResult = this.backgroundTaskServerCollection.updateOne(byServerId, backgroundTaskServerStatusDocumentMapper.toUpdateDocument(serverStatus));
        // "not found" means nothing matched; a match that changed nothing still means the server exists
        if (updateResult.getMatchedCount() < 1) {
            throw new ServerTimedOutException(serverStatus, new StorageException("BackgroundTaskServer with id " + serverStatus.getId() + " was not found"));
        }
        final Document document = this.backgroundTaskServerCollection.find(byServerId).projection(include(BackgroundTaskServers.FIELD_IS_RUNNING)).first();
        return document != null && Boolean.TRUE.equals(document.getBoolean(BackgroundTaskServers.FIELD_IS_RUNNING));
    }

    @Override
    public void signalBackgroundTaskServerStopped(BackgroundTaskServerStatus serverStatus) {
        this.backgroundTaskServerCollection.deleteOne(eq(toMongoId(BackgroundTaskServers.FIELD_ID), serverStatus.getId()));
    }

    /** Returns all registered background task servers, ordered by first heartbeat ascending. */
    @Override
    public List<BackgroundTaskServerStatus> getBackgroundTaskServers() {
        return this.backgroundTaskServerCollection
                .find()
                .sort(ascending(BackgroundTaskServers.FIELD_FIRST_HEARTBEAT))
                .map(backgroundTaskServerStatusDocumentMapper::toBackgroundTaskServerStatus)
                .into(new ArrayList<>());
    }

    /** Returns the id of the server with the earliest first heartbeat, or {@code null} when none is registered. */
    @Override
    public UUID getLongestRunningBackgroundTaskServerId() {
        return this.backgroundTaskServerCollection
                .find()
                .sort(ascending(BackgroundTaskServers.FIELD_FIRST_HEARTBEAT))
                .projection(include(toMongoId(BackgroundTaskServers.FIELD_ID)))
                .map(MongoUtils::getIdAsUUID)
                .first();
    }

    /** Deletes servers whose last heartbeat is older than {@code heartbeatOlderThan}; returns how many were deleted. */
    @Override
    public int removeTimedOutBackgroundTaskServers(Instant heartbeatOlderThan) {
        final DeleteResult deleteResult = this.backgroundTaskServerCollection.deleteMany(
                lt(BackgroundTaskServers.FIELD_LAST_HEARTBEAT, heartbeatOlderThan)
        );
        return (int) deleteResult.getDeletedCount();
    }

    /** Upserts the metadata document and notifies metadata change listeners. */
    @Override
    public void saveMetadata(CarrotMetadata metadata) {
        metadataCollection.updateOne(eq(toMongoId(Metadata.FIELD_ID), metadata.getId()), metadataDocumentMapper.toUpdateDocument(metadata), new UpdateOptions().upsert(true));
        notifyMetadataChangeListeners();
    }

    @Override
    public List<CarrotMetadata> getMetadata(String name) {
        return metadataCollection.find(eq(Metadata.FIELD_NAME, name))
                .map(metadataDocumentMapper::toCarrotMetadata)
                .into(new ArrayList<>());
    }

    @Override
    public CarrotMetadata getMetadata(String name, String owner) {
        Document document = metadataCollection.find(eq(toMongoId(Metadata.FIELD_ID), toId(name, owner))).first();
        return metadataDocumentMapper.toCarrotMetadata(document);
    }

    @Override
    public void deleteMetadata(String name) {
        final DeleteResult deleteResult = metadataCollection.deleteMany(eq(Metadata.FIELD_NAME, name));
        long deletedCount = deleteResult.getDeletedCount();
        notifyMetadataChangeListeners(deletedCount > 0);
    }

    /**
     * Saves a task to the partition named by its {@code PARTITION_HINT_KEY} metadata entry, or to a
     * partition chosen by the partitioner when no hint is present.
     */
    @Override
    public Task save(Task task) {
        return this.saveByPartition(task, resolvePartition(task));
    }

    /**
     * Inserts a new task or updates an existing one (optimistic locking via {@link TaskVersioner}).
     *
     * @throws ConcurrentTaskModificationException on a duplicate key or when the update matched no document
     * @throws StorageException                    on any other MongoDB error
     */
    @Override
    public Task saveByPartition(Task task, Integer partition) {
        MongoCollection<Document> taskCollection = validateThenGetTaskPartition(partition);
        try (TaskVersioner taskVersioner = new TaskVersioner(task)) {
            if (taskVersioner.isNewTask()) {
                taskCollection.insertOne(taskDocumentMapper.toInsertDocument(task));
            } else {
                final UpdateOneModel<Document> updateModel = taskDocumentMapper.toUpdateOneModel(task);
                final UpdateResult updateResult = taskCollection.updateOne(updateModel.getFilter(), updateModel.getUpdate());
                if (updateResult.getModifiedCount() < 1) {
                    throw new ConcurrentTaskModificationException(task);
                }
            }
            taskVersioner.commitVersion();
        } catch (MongoWriteException e) {
            // 11000 = duplicate key error
            if (e.getError().getCode() == 11000) throw new ConcurrentTaskModificationException(task);
            throw new StorageException(e);
        } catch (MongoException e) {
            throw new StorageException(e);
        }
        notifyTaskStatsOnChangeListeners();
        return task;
    }

    @Override
    public int deletePermanentlyByPartition(UUID id, Integer partition) {
        MongoCollection<Document> taskCollection = validateThenGetTaskPartition(partition);
        final DeleteResult result = taskCollection.deleteOne(eq(toMongoId(Tasks.FIELD_ID), id));
        final int deletedCount = (int) result.getDeletedCount();
        notifyTaskStatsOnChangeListenersIf(deletedCount > 0);
        return deletedCount;
    }

    /**
     * Searches all partitions for the task. The partition it was found in is stored in the task's
     * metadata under {@code PARTITION_HINT_KEY} so a later {@link #save(Task)} writes back to it.
     *
     * @throws TaskNotFoundException if no partition contains the task
     */
    @Override
    public Task getTaskById(UUID id) {
        for (int p = 0; p < totalNumOfPartitions; p++) {
            MongoCollection<Document> taskCollection = validateThenGetTaskPartition(p);
            final Document document = taskCollection.find(eq(toMongoId(Tasks.FIELD_ID), id)).projection(include(Tasks.FIELD_TASK_AS_JSON)).first();
            if (document != null) {
                Task task = taskDocumentMapper.toTask(document);
                // add partition hint, so it can be saved into the same partition later
                task.getMetadata().put(PARTITION_HINT_KEY, p);
                return task;
            }
        }
        throw new TaskNotFoundException(id);
    }

    /**
     * Saves all tasks into a single partition, chosen from the first task the same way
     * {@link #save(Task)} chooses it. An empty list is returned unchanged.
     */
    @Override
    public List<Task> save(List<Task> tasks) {
        if (tasks.isEmpty()) {
            return tasks;
        }
        return this.saveByPartition(tasks, resolvePartition(tasks.get(0)));
    }

    /**
     * Inserts (when all tasks are new) or bulk-updates the tasks in the given partition.
     *
     * @throws ConcurrentTaskModificationException when some tasks were modified or removed concurrently
     * @throws StorageException                    on a MongoDB error
     */
    @Override
    public List<Task> saveByPartition(List<Task> tasks, Integer partition) {
        MongoCollection<Document> taskCollection = validateThenGetTaskPartition(partition);
        if (tasks.isEmpty()) {
            // nothing to write; the MongoDB driver rejects empty insertMany / bulkWrite requests
            return tasks;
        }
        try (TaskListVersioner taskListVersioner = new TaskListVersioner(tasks)) {
            if (taskListVersioner.areNewTasks()) {
                final List<Document> tasksToInsert = tasks.stream()
                        .map(task -> taskDocumentMapper.toInsertDocument(task))
                        .collect(Collectors.toList());
                taskCollection.insertMany(tasksToInsert);
            } else {
                final List<WriteModel<Document>> tasksToUpdate = tasks.stream()
                        .map(task -> taskDocumentMapper.toUpdateOneModel(task))
                        .collect(Collectors.toList());
                final BulkWriteResult bulkWriteResult = taskCollection.bulkWrite(tasksToUpdate);
                if (bulkWriteResult.getModifiedCount() != tasks.size()) {
                    final List<Task> concurrentModificationTasks = findConcurrentlyModifiedTasks(taskCollection, tasks);
                    taskListVersioner.rollbackVersions(concurrentModificationTasks);
                    throw new ConcurrentTaskModificationException(concurrentModificationTasks);
                }
            }
            taskListVersioner.commitVersions();
        } catch (MongoException e) {
            throw new StorageException(e);
        }
        notifyTaskStatsOnChangeListenersIf(!tasks.isEmpty());
        return tasks;
    }

    /**
     * Works out which tasks of a partially failed bulk update were changed concurrently.
     * The bulk result does not say which documents were not updated, so all of them are reloaded
     * and their {@code updatedAt} is compared. A task whose document no longer exists also counts.
     */
    private List<Task> findConcurrentlyModifiedTasks(MongoCollection<Document> taskCollection, List<Task> tasks) {
        final Map<UUID, Task> storedTasks = new HashMap<>();
        taskCollection
                .find(in(toMongoId(Tasks.FIELD_ID), tasks.stream().map(Task::getId).collect(Collectors.toList())))
                .projection(include(Tasks.FIELD_TASK_AS_JSON))
                .map(taskDocumentMapper::toTask)
                .forEach(task -> storedTasks.put(task.getId(), task));

        return tasks.stream()
                .filter(task -> {
                    final Task stored = storedTasks.get(task.getId());
                    return stored == null || !task.getUpdatedAt().equals(stored.getUpdatedAt());
                })
                .collect(Collectors.toList());
    }

    @Override
    public List<Task> getTasksByPartition(StateName state, Instant updatedBefore, PageRequest pageRequest, Integer partition) {
        return findTasks(and(eq(Tasks.FIELD_STATE, state.name()), lt(Tasks.FIELD_UPDATED_AT, toMicroSeconds(updatedBefore))), pageRequest, partition);
    }

    @Override
    public List<Task> getScheduledTasksByPartition(Instant scheduledBefore, PageRequest pageRequest, Integer partition) {
        return findTasks(and(eq(Tasks.FIELD_STATE, SCHEDULED.name()), lt(Tasks.FIELD_SCHEDULED_AT, toMicroSeconds(scheduledBefore))), pageRequest, partition);
    }

    @Override
    public List<Task> getTasksByPartition(StateName state, PageRequest pageRequest, Integer partition) {
        return findTasks(eq(Tasks.FIELD_STATE, state.name()), pageRequest, partition);
    }

    @Override
    public Page<Task> getTaskPageByPartition(StateName state, PageRequest pageRequest, Integer partition) {
        return getTaskPageByPartition(eq(Tasks.FIELD_STATE, state.name()), pageRequest, partition);
    }

    /**
     * Deletes the tasks in {@code state} that were last updated before {@code updatedBefore}.
     *
     * @return the number of deleted tasks
     */
    @Override
    public int deleteTasksPermanentlyByPartition(StateName state, Instant updatedBefore, Integer partition) {
        MongoCollection<Document> taskCollection = validateThenGetTaskPartition(partition);
        final DeleteResult deleteResult = taskCollection.deleteMany(and(eq(Tasks.FIELD_STATE, state.name()), lt(Tasks.FIELD_UPDATED_AT, toMicroSeconds(updatedBefore))));
        final long deletedCount = deleteResult.getDeletedCount();
        notifyTaskStatsOnChangeListenersIf(deletedCount > 0);
        return (int) deletedCount;
    }

    /** Returns the union, over all partitions, of the task signatures of tasks in any of the given states. */
    @Override
    public Set<String> getDistinctTaskSignatures(StateName... states) {
        Set<String> resultSet = new HashSet<>();
        for (int p = 0; p < totalNumOfPartitions; p++) {
            resultSet.addAll(getDistinctTaskSignaturesByPartition(p, states));
        }
        return resultSet;
    }

    private Set<String> getDistinctTaskSignaturesByPartition(Integer partition, StateName... states) {
        MongoCollection<Document> taskCollection = validateThenGetTaskPartition(partition);
        return taskCollection
                .distinct(Tasks.FIELD_TASK_SIGNATURE, in(Tasks.FIELD_STATE, stream(states).map(Enum::name).collect(Collectors.toSet())), String.class)
                .into(new HashSet<>());
    }

    @Override
    public boolean existsByPartition(TaskDetails taskDetails, Integer partition, StateName... states) {
        MongoCollection<Document> taskCollection = validateThenGetTaskPartition(partition);
        return taskCollection.countDocuments(and(in(Tasks.FIELD_STATE, stream(states).map(Enum::name).collect(Collectors.toSet())), eq(Tasks.FIELD_TASK_SIGNATURE, getTaskSignature(taskDetails)))) > 0;
    }

    @Override
    public boolean recurringTaskExistsByPartition(String recurringTaskId, Integer partition, StateName... states) {
        MongoCollection<Document> taskCollection = validateThenGetTaskPartition(partition);
        return taskCollection.countDocuments(and(in(Tasks.FIELD_STATE, stream(states).map(Enum::name).collect(Collectors.toSet())), eq(Tasks.FIELD_RECURRING_TASK_ID, recurringTaskId))) > 0;
    }

    @Override
    public RecurringTask saveRecurringTask(RecurringTask recurringTask) {
        Integer partition = partitioner.partition(recurringTask);
        return this.saveRecurringTaskByPartition(recurringTask, partition);
    }

    /** Upserts (replace-or-insert) the recurring task into the given partition. */
    public RecurringTask saveRecurringTaskByPartition(RecurringTask recurringTask, Integer partition) {
        MongoCollection<Document> recurringTaskCollection = validateThenGetRecurringTaskPartition(partition);
        recurringTaskCollection.replaceOne(eq(toMongoId(Tasks.FIELD_ID), recurringTask.getId()), taskDocumentMapper.toInsertDocument(recurringTask), new ReplaceOptions().upsert(true));
        return recurringTask;
    }

    @Override
    public List<RecurringTask> getRecurringTasksByPartition(Integer partition) {
        MongoCollection<Document> recurringTaskCollection = validateThenGetRecurringTaskPartition(partition);
        return recurringTaskCollection.find().map(taskDocumentMapper::toRecurringTask).into(new ArrayList<>());
    }

    /** Returns the recurring tasks of all partitions, concatenated in partition order. */
    @Override
    public List<RecurringTask> getRecurringTasks() {
        List<RecurringTask> results = new ArrayList<>();
        for (int p = 0; p < totalNumOfPartitions; p++) {
            results.addAll(getRecurringTasksByPartition(p));
        }
        return results;
    }

    @Override
    public long countRecurringTasksByPartition(Integer partition) {
        MongoCollection<Document> recurringTaskCollection = validateThenGetRecurringTaskPartition(partition);
        return recurringTaskCollection.countDocuments();
    }

    /** Deletes the recurring task from the first partition that contains it; returns the deleted count (0 or 1). */
    @Override
    public int deleteRecurringTask(String id) {
        for (int p = 0; p < this.totalNumOfPartitions; p++) {
            int count = deleteRecurringTaskByPartition(id, p);
            if (count > 0) return count;
        }
        return 0;
    }

    private int deleteRecurringTaskByPartition(String id, Integer partition) {
        MongoCollection<Document> recurringTaskCollection = validateThenGetRecurringTaskPartition(partition);
        final DeleteResult deleteResult = recurringTaskCollection.deleteOne(eq(toMongoId(Tasks.FIELD_ID), id));
        return (int) deleteResult.getDeletedCount();
    }

    /**
     * Computes per-partition task statistics plus an overall total.
     * Per-state counts and recurring task counts are summed over all partitions. The timestamp, the
     * all-time succeeded count and the background task server count are global values, taken from partition 0.
     */
    @Override
    public TaskStatsData getTaskStatsData() {
        TaskStatsData data = new TaskStatsData();

        TaskStats firstPartitionStats = calculateTaskStats(0);
        data.getTaskStatsList().add(firstPartitionStats);

        long scheduledCount = firstPartitionStats.getScheduled();
        long enqueuedCount = firstPartitionStats.getEnqueued();
        long processingCount = firstPartitionStats.getProcessing();
        long succeededCount = firstPartitionStats.getSucceeded();
        long failedCount = firstPartitionStats.getFailed();
        long deletedCount = firstPartitionStats.getDeleted();
        long total = firstPartitionStats.getTotal();
        int recurringTaskCount = firstPartitionStats.getRecurringTasks();

        for (int p = 1; p < totalNumOfPartitions; p++) {
            TaskStats taskStats = calculateTaskStats(p);
            data.getTaskStatsList().add(taskStats);

            scheduledCount += taskStats.getScheduled();
            enqueuedCount += taskStats.getEnqueued();
            processingCount += taskStats.getProcessing();
            succeededCount += taskStats.getSucceeded();
            failedCount += taskStats.getFailed();
            deletedCount += taskStats.getDeleted();
            total += taskStats.getTotal();
            recurringTaskCount += taskStats.getRecurringTasks();
        }

        TaskStats overallTaskStats = new TaskStats(
                firstPartitionStats.getTimeStamp(),
                total,
                scheduledCount,
                enqueuedCount,
                processingCount,
                failedCount,
                succeededCount,
                firstPartitionStats.getAllTimeSucceeded(),
                deletedCount,
                recurringTaskCount,
                firstPartitionStats.getBackgroundTaskServers()
        );
        data.setOverallTaskStats(overallTaskStats);

        return data;
    }

    /**
     * Computes statistics for one partition.
     * {@code total} counts scheduled, enqueued, processing, succeeded and failed tasks, but not deleted ones.
     */
    private TaskStats calculateTaskStats(Integer partition) {
        Instant instant = Instant.now();
        final Document succeededTaskStats = metadataCollection.find(eq(toMongoId(Metadata.FIELD_ID), Metadata.STATS_ID)).first();
        final long allTimeSucceededCount = (succeededTaskStats != null ? ((Number) succeededTaskStats.get(Metadata.FIELD_VALUE)).longValue() : 0L);

        MongoCollection<Document> taskCollection = validateThenGetTaskPartition(partition);
        // one result document per state: {_id: <state name>, <FIELD_STATE>: <count>}
        final List<Document> aggregates = taskCollection.aggregate(asList(
                match(ne(Tasks.FIELD_STATE, null)),
                group("$state", Accumulators.sum(Tasks.FIELD_STATE, 1)),
                limit(10)
        )).into(new ArrayList<>());

        Long scheduledCount = getCount(SCHEDULED, aggregates);
        Long enqueuedCount = getCount(ENQUEUED, aggregates);
        Long processingCount = getCount(PROCESSING, aggregates);
        Long succeededCount = getCount(SUCCEEDED, aggregates);
        Long failedCount = getCount(FAILED, aggregates);
        Long deletedCount = getCount(DELETED, aggregates);

        final long total = scheduledCount + enqueuedCount + processingCount + succeededCount + failedCount;
        MongoCollection<Document> recurringTaskCollection = validateThenGetRecurringTaskPartition(partition);
        final int recurringTaskCount = (int) recurringTaskCollection.countDocuments();
        final int backgroundTaskServerCount = (int) backgroundTaskServerCollection.countDocuments();

        return new TaskStats(
                instant,
                total,
                scheduledCount,
                enqueuedCount,
                processingCount,
                failedCount,
                succeededCount,
                allTimeSucceededCount,
                deletedCount,
                recurringTaskCount,
                backgroundTaskServerCount
        );
    }

    /** Atomically increments the stored all-time succeeded counter by {@code amount}, creating it if absent. */
    @Override
    public void publishTotalAmountOfSucceededTasks(int amount) {
        metadataCollection.updateOne(eq(toMongoId(Metadata.FIELD_ID), Metadata.STATS_ID), Updates.inc(Metadata.FIELD_VALUE, amount), new UpdateOptions().upsert(true));
    }

    /**
     * Rejects clients whose UUID codec has an UNSPECIFIED representation.
     * The codec registry is read reflectively because not every MongoClient exposes {@code getCodecRegistry}.
     */
    private void validateMongoClient(MongoClient mongoClient) {
        Optional<Method> codecRegistryGetter = findMethod(mongoClient, "getCodecRegistry");
        if (codecRegistryGetter.isPresent()) {
            try {
                CodecRegistry codecRegistry = (CodecRegistry) codecRegistryGetter.get().invoke(mongoClient);
                Object uuidCodec = codecRegistry.get(UUID.class);
                // a custom (non-UuidCodec) UUID codec is left to the user's responsibility
                if (uuidCodec instanceof UuidCodec && UuidRepresentation.UNSPECIFIED == ((UuidCodec) uuidCodec).getUuidRepresentation()) {
                    throw new StorageException("\n" +
                            "Since release 4.0.0 of the MongoDB Java Driver, the default BSON representation of java.util.UUID values has changed from JAVA_LEGACY to UNSPECIFIED.\n" +
                            "Applications that store or retrieve UUID values must explicitly specify which representation to use, via the uuidRepresentation property of MongoClientSettings.\n" +
                            "The good news is that Carrot works both with the STANDARD as the JAVA_LEGACY uuidRepresentation. Please choose the one most appropriate for your application.");
                }
            } catch (IllegalAccessException | InvocationTargetException e) {
                throw shouldNotHappenException(e);
            }
        }
    }

    private void runMigrations() {
        new MongoDBCreator(mongoClient, databaseName, collectionPrefix, totalNumOfPartitions).runMigrations();
    }

    private void validateTables() {
        new MongoDBCreator(mongoClient, databaseName, collectionPrefix, totalNumOfPartitions).validateCollections();
    }

    /**
     * Resolves the target partition for a task: the {@code PARTITION_HINT_KEY} metadata entry when it is a
     * number (the stored value may come back as any numeric type after (de)serialization), otherwise the
     * partitioner's choice.
     */
    private Integer resolvePartition(Task task) {
        final Object hint = task.getMetadata() != null ? task.getMetadata().get(PARTITION_HINT_KEY) : null;
        if (hint instanceof Number) {
            return ((Number) hint).intValue();
        }
        return partitioner.partition(task);
    }

    /** Returns one page of tasks matching {@code query}, together with the total match count. */
    private Page<Task> getTaskPageByPartition(Bson query, PageRequest pageRequest, Integer partition) {
        MongoCollection<Document> taskCollection = validateThenGetTaskPartition(partition);
        long count = taskCollection.countDocuments(query);
        if (count > 0) {
            List<Task> tasks = findTasks(query, pageRequest, partition);
            return new Page<>(count, tasks, pageRequest);
        }
        return new Page<>(0, new ArrayList<>(), pageRequest);
    }

    private List<Task> findTasks(Bson query, PageRequest pageRequest, Integer partition) {
        MongoCollection<Document> taskCollection = validateThenGetTaskPartition(partition);
        return taskCollection
                .find(query)
                .sort(pageRequestMapper.map(pageRequest))
                .skip((int) pageRequest.getOffset())
                .limit(pageRequest.getLimit())
                .projection(include(Tasks.FIELD_TASK_AS_JSON))
                .map(taskDocumentMapper::toTask)
                .into(new ArrayList<>());
    }

    /** @throws IllegalArgumentException if no task collection exists for {@code partition} */
    private MongoCollection<Document> validateThenGetTaskPartition(Integer partition) {
        MongoCollection<Document> taskCollection = taskCollectionMap.get(partition);
        if (taskCollection == null)
            throw new IllegalArgumentException(String.format(ERR_MSG_INVALID_PARTITION, partition));
        return taskCollection;
    }

    /** @throws IllegalArgumentException if no recurring task collection exists for {@code partition} */
    private MongoCollection<Document> validateThenGetRecurringTaskPartition(Integer partition) {
        MongoCollection<Document> recurringTaskCollection = recurringTaskCollectionMap.get(partition);
        if (recurringTaskCollection == null)
            throw new IllegalArgumentException(String.format(ERR_MSG_INVALID_PARTITION, partition));
        return recurringTaskCollection;
    }

    /** Converts an instant to microseconds since the epoch, the unit used for the stored timestamps. */
    private long toMicroSeconds(Instant instant) {
        return ChronoUnit.MICROS.between(Instant.EPOCH, instant);
    }

    /**
     * Extracts the count for {@code stateName} from the state aggregation result, or 0 if absent.
     * The value is read as a {@link Number} because {@code $sum} may yield an Integer or a Long.
     */
    private Long getCount(StateName stateName, List<Document> aggregates) {
        return aggregates.stream()
                .filter(document -> stateName.name().equals(document.get(toMongoId(Tasks.FIELD_ID))))
                .findFirst()
                .map(document -> document.get(Tasks.FIELD_STATE))
                .map(count -> ((Number) count).longValue())
                .orElse(0L);
    }

    /** Prefixes a field name with an underscore (e.g. {@code "id"} becomes {@code "_id"}). */
    public static String toMongoId(String id) {
        return "_" + id;
    }
}
