package cn.boboweike.carrot.server;

import cn.boboweike.carrot.CarrotException;
import cn.boboweike.carrot.SevereCarrotException;
import cn.boboweike.carrot.server.concurrent.ConcurrentTaskModificationResolver;
import cn.boboweike.carrot.server.concurrent.UnresolvableConcurrentTaskModificationException;
import cn.boboweike.carrot.server.dashboard.DashboardNotificationManager;
import cn.boboweike.carrot.server.strategy.WorkDistributionStrategy;
import cn.boboweike.carrot.storage.BackgroundTaskServerStatus;
import cn.boboweike.carrot.storage.ConcurrentTaskModificationException;
import cn.boboweike.carrot.storage.PageRequest;
import cn.boboweike.carrot.storage.PartitionedStorageProvider;
import cn.boboweike.carrot.tasks.RecurringTask;
import cn.boboweike.carrot.tasks.Task;
import cn.boboweike.carrot.tasks.filters.TaskFilterUtils;
import cn.boboweike.carrot.tasks.states.StateName;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Supplier;

import static cn.boboweike.carrot.server.BackgroundTaskServer.NO_PARTITION;
import static cn.boboweike.carrot.storage.PageRequest.ascOnUpdatedAt;
import static cn.boboweike.carrot.tasks.states.StateName.PROCESSING;
import static cn.boboweike.carrot.tasks.states.StateName.SUCCEEDED;
import static java.time.Duration.ofSeconds;
import static java.time.Instant.now;
import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.toList;

/**
 * Periodic housekeeping job of a {@link BackgroundTaskServer}.
 * <p>
 * Every {@link #run()} refreshes the tasks this server is currently processing, performs the
 * routine maintenance checks (recurring, scheduled, orphaned, succeeded and deleted tasks) and
 * finally tries to onboard new enqueued work. All storage access is scoped to the partition
 * reported by the server; while the server reports no partition, the partition-scoped steps
 * are skipped.
 * <p>
 * The class also tracks which {@link Task} is processed by which {@link Thread} and how many
 * workers are occupied.
 */
public class TaskZooKeeper implements Runnable {
    static final Logger LOGGER = LoggerFactory.getLogger(TaskZooKeeper.class);

    private final BackgroundTaskServer backgroundTaskServer;
    private final PartitionedStorageProvider storageProvider;
    /** Locally cached recurring tasks; reloaded by {@link #getRecurringTasks(Integer)} when the stored count differs. */
    private final List<RecurringTask> recurringTasks;
    private final DashboardNotificationManager dashboardNotificationManager;
    private final TaskFilterUtils taskFilterUtils;
    private final WorkDistributionStrategy workDistributionStrategy;
    private final ConcurrentTaskModificationResolver concurrentTaskModificationResolver;
    private final Map<Task, Thread> currentlyProcessedTasks;
    /** Number of exceptions caught by {@link #run()} so far. */
    private final AtomicInteger exceptionCount;
    /** Guards {@link #checkForEnqueuedTasks()} so that only one thread fetches enqueued work at a time. */
    private final ReentrantLock reentrantLock;
    private final AtomicInteger occupiedWorkers;
    /** Time budget of a single run: 95% of the poll interval, truncated to whole seconds. */
    private final Duration durationPollIntervalTimeBox;
    /** Start of the current run; set at the beginning of {@link #run()}. */
    private Instant runStartTime;

    /**
     * Creates the zookeeper for the given server and derives all collaborators from it.
     *
     * @param backgroundTaskServer the server this zookeeper works for
     */
    public TaskZooKeeper(BackgroundTaskServer backgroundTaskServer) {
        this.backgroundTaskServer = backgroundTaskServer;
        this.storageProvider = backgroundTaskServer.getStorageProvider();
        this.recurringTasks = new ArrayList<>();
        this.workDistributionStrategy = backgroundTaskServer.getWorkDistributionStrategy();
        this.dashboardNotificationManager = backgroundTaskServer.getDashboardNotificationManager();
        this.taskFilterUtils = new TaskFilterUtils(backgroundTaskServer.getTaskFilters());
        this.concurrentTaskModificationResolver = createConcurrentTaskModificationResolver();
        this.currentlyProcessedTasks = new ConcurrentHashMap<>();
        // Read the poll interval once; the cast truncates the 95% value to whole seconds.
        final long pollIntervalInSeconds = backgroundTaskServerStatus().getPollIntervalInSeconds();
        this.durationPollIntervalTimeBox = Duration.ofSeconds((long) (pollIntervalInSeconds - (pollIntervalInSeconds * 0.05)));
        this.reentrantLock = new ReentrantLock();
        this.exceptionCount = new AtomicInteger();
        this.occupiedWorkers = new AtomicInteger();
    }

    /**
     * Executes one maintenance cycle. Does nothing when the server is unannounced. Any exception
     * is reported to the dashboard notification manager, counted and logged; it is not rethrown.
     */
    @Override
    public void run() {
        try {
            runStartTime = Instant.now();
            if (backgroundTaskServer.isUnAnnounced()) return;

            updateTasksThatAreBeingProcessed();
            runRoutineTasks();
            onboardNewWorkIfPossible();
        } catch (Exception e) {
            dashboardNotificationManager.handle(e);
            exceptionCount.getAndIncrement();
            LOGGER.warn(CarrotException.SHOULD_NOT_HAPPEN_MESSAGE +
                    " - Processing will continue, exceptionCount = {}", exceptionCount, e);
        }
    }

    /** Calls {@link Task#updateProcessing()} on a snapshot of the currently processed tasks and saves them. */
    void updateTasksThatAreBeingProcessed() {
        LOGGER.debug("Updating currently processed tasks... ");
        processTaskList(new ArrayList<>(currentlyProcessedTasks.keySet()), this::updateCurrentlyProcessingTask);
    }

    /** Runs all routine maintenance checks in a fixed order. */
    void runRoutineTasks() {
        checkForRecurringTasks();
        checkForScheduledTasks();
        checkForOrphanedTasks();
        checkForSucceededTasksThanCanGoToDeletedState();
        checkForTasksThatCanBeDeleted();
    }

    /**
     * @return {@code true} when the server status is running and the work distribution strategy
     *         allows onboarding new work
     */
    boolean canOnboardNewWork() {
        return backgroundTaskServerStatus().isRunning() && workDistributionStrategy.canOnboardNewWork();
    }

    /** Loads the recurring tasks of the current partition and schedules the ones that are due. */
    void checkForRecurringTasks() {
        Integer partition = getPartition();
        if (isNoPartition(partition)) return;
        LOGGER.debug("Looking for recurring tasks... ");
        List<RecurringTask> partitionRecurringTasks = getRecurringTasks(partition);
        processRecurringTasks(partitionRecurringTasks);
    }

    /** Enqueues scheduled tasks of the current partition that are due within the next poll interval. */
    void checkForScheduledTasks() {
        Integer partition = getPartition();
        if (isNoPartition(partition)) return;
        LOGGER.debug("Looking for scheduled tasks... ");
        Supplier<List<Task>> scheduledTasksSupplier = () ->
                storageProvider.getScheduledTasksByPartition(now().plusSeconds(backgroundTaskServerStatus().getPollIntervalInSeconds()),
                        ascOnUpdatedAt(1000),
                        partition);
        processTaskList(scheduledTasksSupplier, Task::enqueue);
    }

    /**
     * Marks tasks of the current partition as failed when they are in the PROCESSING state and
     * were last updated more than four poll intervals before the start of this run.
     */
    void checkForOrphanedTasks() {
        Integer partition = getPartition();
        if (isNoPartition(partition)) return;
        LOGGER.debug("Looking for orphan tasks... ");
        final Instant updatedBefore = runStartTime.minus(ofSeconds(backgroundTaskServer.getServerStatus().getPollIntervalInSeconds()).multipliedBy(4));
        Supplier<List<Task>> orphanedTasksSupplier = () -> storageProvider.getTasksByPartition(PROCESSING, updatedBefore, ascOnUpdatedAt(1000), partition);
        processTaskList(orphanedTasksSupplier, task -> task.failed("Orphaned task", new IllegalThreadStateException("Task was too long in PROCESSING state without being updated.")));
    }

    /**
     * Moves succeeded tasks of the current partition that are older than the configured
     * "delete succeeded tasks after" duration to the deleted state and publishes how many tasks
     * were handled when that number is greater than zero.
     */
    void checkForSucceededTasksThanCanGoToDeletedState() {
        Integer partition = getPartition();
        if (isNoPartition(partition)) return;
        LOGGER.debug("Looking for succeeded tasks that can go to the deleted state... ");
        AtomicInteger succeededTasksCounter = new AtomicInteger();

        final Instant updatedBefore = now().minus(backgroundTaskServer.getServerStatus().getDeleteSucceededTasksAfter());
        Supplier<List<Task>> succeededTasksSupplier = () -> storageProvider.getTasksByPartition(SUCCEEDED, updatedBefore, ascOnUpdatedAt(1000), partition);
        processTaskList(succeededTasksSupplier, task -> {
            succeededTasksCounter.incrementAndGet();
            task.delete("Carrot maintenance - deleting succeeded task");
        });

        if (succeededTasksCounter.get() > 0) {
            storageProvider.publishTotalAmountOfSucceededTasks(succeededTasksCounter.get());
        }
    }

    /**
     * Permanently deletes tasks of the current partition that are in the DELETED state and older
     * than the configured "permanently delete deleted tasks after" duration.
     */
    void checkForTasksThatCanBeDeleted() {
        Integer partition = getPartition();
        if (isNoPartition(partition)) return;
        LOGGER.debug("Looking for deleted tasks that can be deleted permanently... ");
        storageProvider.deleteTasksPermanentlyByPartition(StateName.DELETED, now().minus(backgroundTaskServer.getServerStatus().getPermanentlyDeleteDeletedTasksAfter()), partition);
    }

    /** @return the partition reported by the server, or {@code NO_PARTITION} when it reports {@code null} */
    private Integer getPartition() {
        Integer partition = this.backgroundTaskServer.getPartition();
        if (partition == null) return NO_PARTITION;
        return partition;
    }

    /**
     * Value-based check against {@code NO_PARTITION}. Comparing boxed {@link Integer}s with
     * {@code ==} compares references and only works by accident for cached values, hence
     * {@code equals} is used here.
     */
    private static boolean isNoPartition(Integer partition) {
        return partition == null || partition.equals(NO_PARTITION);
    }

    /** Fetches enqueued tasks unless the run's time box is used up or no new work can be onboarded. */
    void onboardNewWorkIfPossible() {
        if (pollIntervalInSecondsTimeBoxIsAboutToPass()) return;
        if (canOnboardNewWork()) {
            checkForEnqueuedTasks();
        }
    }

    /**
     * Fetches enqueued tasks of the current partition and hands each of them to the server for
     * processing. The method returns immediately when the lock cannot be acquired, i.e. when
     * another thread is already fetching work.
     */
    void checkForEnqueuedTasks() {
        Integer partition = getPartition();
        if (isNoPartition(partition)) return;
        // Non-blocking: skip this round when another thread is already fetching work.
        if (!reentrantLock.tryLock()) return;
        try {
            LOGGER.debug("Looking for enqueued tasks... ");
            final PageRequest workPageRequest = workDistributionStrategy.getWorkPageRequest();
            if (workPageRequest.getLimit() > 0) {
                final List<Task> enqueuedTasks = storageProvider.getTasksByPartition(StateName.ENQUEUED, workPageRequest, partition);
                enqueuedTasks.forEach(backgroundTaskServer::processTask);
            }
        } finally {
            // The lock is held at this point, so exactly one unlock balances the tryLock above.
            reentrantLock.unlock();
        }
    }

    /**
     * Converts every recurring task that must be scheduled into a scheduled task and saves the
     * resulting tasks in the current partition.
     *
     * @param recurringTasks the recurring tasks to inspect
     */
    void processRecurringTasks(List<RecurringTask> recurringTasks) {
        Integer partition = getPartition();
        if (isNoPartition(partition)) return;
        LOGGER.debug("Found {} recurring tasks", recurringTasks.size());
        List<Task> tasksToSchedule = recurringTasks.stream()
                .filter(rt -> mustSchedule(rt, partition))
                .map(RecurringTask::toScheduledTask)
                .collect(toList());
        if (!tasksToSchedule.isEmpty()) {
            storageProvider.saveByPartition(tasksToSchedule, partition);
        }
    }

    /**
     * @param recurringTask the recurring task to check
     * @param partition     the partition to look for existing task instances in
     * @return {@code true} when the next run lies before "now + time box + 1 second" and no task
     *         of this recurring task exists in the SCHEDULED, ENQUEUED or PROCESSING state
     */
    boolean mustSchedule(RecurringTask recurringTask, Integer partition) {
        return recurringTask.getNextRun().isBefore(now().plus(durationPollIntervalTimeBox).plusSeconds(1))
                && !storageProvider.recurringTaskExistsByPartition(recurringTask.getId(), partition, StateName.SCHEDULED, StateName.ENQUEUED, StateName.PROCESSING);
    }

    /**
     * Repeatedly asks the supplier for tasks and processes them until the supplier returns an
     * empty list or the time box of the current run is used up.
     *
     * @param taskListSupplier provides the next batch of tasks
     * @param taskConsumer     the state change to apply to each task
     */
    void processTaskList(Supplier<List<Task>> taskListSupplier, Consumer<Task> taskConsumer) {
        List<Task> tasks = getTasksToProcess(taskListSupplier);
        while (!tasks.isEmpty()) {
            processTaskList(tasks, taskConsumer);
            tasks = getTasksToProcess(taskListSupplier);
        }
    }

    /**
     * Applies the consumer to every task, runs the state election filter, saves the tasks in the
     * current partition and runs the state applied filters. A concurrent modification is handed
     * to the configured resolver.
     *
     * @param tasks        the tasks to process
     * @param taskConsumer the state change to apply to each task
     * @throws SevereCarrotException when the resolver cannot resolve a concurrent modification
     */
    void processTaskList(List<Task> tasks, Consumer<Task> taskConsumer) {
        Integer partition = getPartition();
        if (isNoPartition(partition)) return;
        if (tasks.isEmpty()) return;
        try {
            tasks.forEach(taskConsumer);
            taskFilterUtils.runOnStateElectionFilter(tasks);
            storageProvider.saveByPartition(tasks, partition);
            taskFilterUtils.runOnStateAppliedFilters(tasks);
        } catch (ConcurrentTaskModificationException concurrentTaskModificationException) {
            try {
                concurrentTaskModificationResolver.resolve(concurrentTaskModificationException);
            } catch (UnresolvableConcurrentTaskModificationException unresolvableConcurrentTaskModificationException) {
                throw new SevereCarrotException("Could not resolve ConcurrentTaskModificationException", unresolvableConcurrentTaskModificationException);
            }
        }
    }

    /** @return the current status of the server */
    BackgroundTaskServerStatus backgroundTaskServerStatus() {
        return backgroundTaskServer.getServerStatus();
    }

    /**
     * Registers the task as being processed by the given thread.
     *
     * @param task   the task that starts processing
     * @param thread the thread processing the task
     */
    public void startProcessing(Task task, Thread thread) {
        currentlyProcessedTasks.put(task, thread);
    }

    /**
     * Removes the task from the set of currently processed tasks.
     *
     * @param task the task that stopped processing
     */
    public void stopProcessing(Task task) {
        currentlyProcessedTasks.remove(task);
    }

    /**
     * @param task the task to look up
     * @return the thread registered for the task, or {@code null} when the task is not registered
     */
    public Thread getThreadProcessingTask(Task task) {
        return currentlyProcessedTasks.get(task);
    }

    /** @return the number of workers currently marked as occupied */
    public int getOccupiedWorkerCount() {
        return occupiedWorkers.get();
    }

    /** Marks one more worker as occupied. */
    public void notifyThreadOccupied() {
        occupiedWorkers.incrementAndGet();
    }

    /**
     * Marks one worker as idle again and immediately looks for enqueued tasks when the work
     * distribution strategy allows onboarding new work.
     */
    public void notifyThreadIdle() {
        this.occupiedWorkers.decrementAndGet();
        if (workDistributionStrategy.canOnboardNewWork()) {
            checkForEnqueuedTasks();
        }
    }

    /** @return the supplier's next batch, or an empty list when the time box of the run is used up */
    private List<Task> getTasksToProcess(Supplier<List<Task>> taskListSupplier) {
        if (pollIntervalInSecondsTimeBoxIsAboutToPass()) return emptyList();
        return taskListSupplier.get();
    }

    /** Calls {@link Task#updateProcessing()} and tolerates a {@link ClassCastException}. */
    private void updateCurrentlyProcessingTask(Task task) {
        try {
            task.updateProcessing();
        } catch (ClassCastException ignored) {
            // why: because of thread context switching there is a tiny chance that the task has succeeded
            LOGGER.trace("Could not update processing state of a task; ignoring.", ignored);
        }
    }

    /** @return {@code true} when the current run has lasted at least as long as the time box */
    private boolean pollIntervalInSecondsTimeBoxIsAboutToPass() {
        final Duration durationRunTime = Duration.between(runStartTime, now());
        final boolean runTimeBoxIsPassed = durationRunTime.compareTo(durationPollIntervalTimeBox) >= 0;
        if (runTimeBoxIsPassed) {
            LOGGER.debug("Carrot is passing the poll interval in seconds timebox because of too many tasks.");
        }
        return runTimeBoxIsPassed;
    }

    /**
     * Returns the cached recurring tasks, reloading them from storage when the cached size
     * differs from the number of recurring tasks stored for the partition.
     */
    private List<RecurringTask> getRecurringTasks(Integer partition) {
        if (this.recurringTasks.size() != storageProvider.countRecurringTasksByPartition(partition)) {
            this.recurringTasks.clear();
            this.recurringTasks.addAll(storageProvider.getRecurringTasksByPartition(partition));
        }
        return this.recurringTasks;
    }

    /** Builds the resolver from the server configuration's concurrent task modification policy. */
    ConcurrentTaskModificationResolver createConcurrentTaskModificationResolver() {
        return backgroundTaskServer.getConfiguration()
                .concurrentTaskModificationPolicy.toConcurrentTaskModificationResolver(storageProvider, this);
    }
}
