package cn.boboweike.carrot.storage;

import cn.boboweike.carrot.lock.LockProvider;
import cn.boboweike.carrot.lock.inmemory.InMemoryLockProvider;
import cn.boboweike.carrot.scheduling.partition.Partitioner;
import cn.boboweike.carrot.tasks.*;
import cn.boboweike.carrot.tasks.mappers.TaskMapper;
import cn.boboweike.carrot.tasks.states.ScheduledState;
import cn.boboweike.carrot.tasks.states.StateName;
import cn.boboweike.carrot.utils.resilience.RateLimiter;

import java.time.Instant;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static cn.boboweike.carrot.storage.StorageProviderUtils.Metadata.*;
import static cn.boboweike.carrot.storage.StorageProviderUtils.Tasks.FIELD_CREATED_AT;
import static cn.boboweike.carrot.storage.StorageProviderUtils.Tasks.FIELD_UPDATED_AT;
import static cn.boboweike.carrot.storage.StorageProviderUtils.returnConcurrentModifiedTasks;
import static cn.boboweike.carrot.utils.TaskUtils.getTaskSignature;
import static cn.boboweike.carrot.utils.reflection.ReflectionUtils.getValueFromFieldOrProperty;
import static cn.boboweike.carrot.utils.reflection.ReflectionUtils.setFieldUsingAutoboxing;
import static cn.boboweike.carrot.utils.resilience.RateLimiter.Builder.rateLimit;
import static cn.boboweike.carrot.utils.resilience.RateLimiter.SECOND;
import static java.lang.Long.parseLong;
import static java.util.Arrays.asList;

/**
 * Storage provider that keeps all of its state in collections owned by this instance: tasks, recurring
 * tasks, background task server statuses and metadata. Nothing is written outside the JVM.
 *
 * <p>Only a single partition exists (see {@link #getTotalNumOfPartitions()}), so the {@code partition}
 * argument of the {@code *ByPartition} task methods is ignored. Partition locks are still keyed by partition
 * and are kept in a JVM-local {@link InMemoryLockProvider}.
 *
 * <p>Tasks are stored and handed out as deep copies produced by the configured {@link TaskMapper}, so callers
 * never share {@link Task} instances with the store. {@link #setTaskMapper(TaskMapper)} must therefore be
 * called before any task is saved or read.
 */
public class InMemoryPartitionedStorageProvider extends AbstractPartitionedStorageProvider {
    // All maps are ConcurrentHashMap instances, so compute/computeIfPresent/remove(key, value) are atomic.
    private final Map<UUID, Task> taskQueue = new ConcurrentHashMap<>();
    private final Map<UUID, BackgroundTaskServerStatus> backgroundTaskServers = new ConcurrentHashMap<>();
    private final List<RecurringTask> recurringTasks = new CopyOnWriteArrayList<>();
    // Keyed by "<name>-<owner>" (see saveMetadata / getMetadata(String, String)).
    private final Map<String, CarrotMetadata> metadata = new ConcurrentHashMap<>();
    private final LockProvider lockProvider = new InMemoryLockProvider();
    private TaskMapper taskMapper;

    /** Creates a provider whose change notifications are rate limited to one request per second. */
    public InMemoryPartitionedStorageProvider() {
        this(rateLimit().at1Request().per(SECOND));
    }

    /**
     * Creates a provider with the given rate limiter and seeds the all-time succeeded counter with 0.
     *
     * @param rateLimiter rate limiter passed to the superclass
     */
    protected InMemoryPartitionedStorageProvider(RateLimiter rateLimiter) {
        super(rateLimiter);
        // NOTE(review): calls an overridable method from the constructor; subclasses overriding it will run
        // before their own fields are initialised.
        publishTotalAmountOfSucceededTasks(0);
    }

    /** @return always 1: the in-memory provider has exactly one partition */
    @Override
    public int getTotalNumOfPartitions() {
        return 1; // only 1 partition for in memory StorageProvider
    }

    @Override
    public boolean lockByPartition(Integer partition, int durationInSeconds, String lockedBy) {
        return this.lockProvider.lock(PARTITION_PREFIX + partition, durationInSeconds, lockedBy);
    }

    @Override
    public boolean extendLockByPartition(Integer partition, int durationInSeconds, String lockedBy) {
        return this.lockProvider.extend(PARTITION_PREFIX + partition, durationInSeconds, lockedBy);
    }

    @Override
    public boolean unlockByPartition(Integer partition) {
        return this.lockProvider.unlock(PARTITION_PREFIX + partition);
    }

    /** Sets the mapper used to deep-clone tasks on every write and read. */
    @Override
    public void setTaskMapper(TaskMapper taskMapper) {
        this.taskMapper = taskMapper;
    }

    @Override
    public void setPartitioner(Partitioner partitioner) {
        // Intentionally a no-op: this provider has a single partition and never consults a partitioner.
    }

    @Override
    public void setUpStorageProvider(StorageProviderUtils.DatabaseOptions databaseOptions) {
        // Intentionally a no-op: there is no external storage to set up, so databaseOptions are ignored.
    }

    /** Stores (or replaces) a private copy of the given server status, keyed by server id. */
    @Override
    public void announceBackgroundTaskServer(BackgroundTaskServerStatus serverStatus) {
        backgroundTaskServers.put(serverStatus.getId(), copyOf(serverStatus));
    }

    /**
     * Replaces the stored status of an already-announced server with a copy of {@code serverStatus}.
     *
     * @return the {@code isRunning} flag of the newly stored status
     * @throws ServerTimedOutException if the server is not (or no longer) registered
     */
    @Override
    public boolean signalBackgroundTaskServerAlive(BackgroundTaskServerStatus serverStatus) {
        // computeIfPresent is atomic, so a server removed concurrently (e.g. by
        // removeTimedOutBackgroundTaskServers) is neither resurrected nor read back as null.
        final BackgroundTaskServerStatus updatedStatus = backgroundTaskServers.computeIfPresent(
                serverStatus.getId(), (id, previousStatus) -> copyOf(serverStatus));
        if (updatedStatus == null) {
            throw new ServerTimedOutException(serverStatus, new StorageException("The server is not there"));
        }
        return updatedStatus.isRunning();
    }

    @Override
    public void signalBackgroundTaskServerStopped(BackgroundTaskServerStatus serverStatus) {
        backgroundTaskServers.remove(serverStatus.getId());
    }

    /** @return all registered servers, ordered by first heartbeat (oldest first) */
    @Override
    public List<BackgroundTaskServerStatus> getBackgroundTaskServers() {
        return backgroundTaskServers.values().stream()
                .sorted(Comparator.comparing(BackgroundTaskServerStatus::getFirstHeartbeat))
                .collect(Collectors.toList());
    }

    /**
     * @return id of the server with the earliest first heartbeat
     * @throws IllegalStateException if no server is registered
     */
    @Override
    public UUID getLongestRunningBackgroundTaskServerId() {
        return backgroundTaskServers.values().stream()
                .min(Comparator.comparing(BackgroundTaskServerStatus::getFirstHeartbeat))
                .map(BackgroundTaskServerStatus::getId)
                .orElseThrow(() -> new IllegalStateException("No servers available!"));
    }

    /**
     * Removes every server whose last heartbeat is strictly before {@code heartbeatOlderThan}.
     *
     * @return number of servers selected for removal
     */
    @Override
    public int removeTimedOutBackgroundTaskServers(Instant heartbeatOlderThan) {
        final List<UUID> serversToRemove = backgroundTaskServers.entrySet().stream()
                .filter(entry -> entry.getValue().getLastHeartbeat().isBefore(heartbeatOlderThan))
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
        backgroundTaskServers.keySet().removeAll(serversToRemove);
        return serversToRemove.size();
    }

    /** Stores metadata under "{@code name-owner}", replacing any previous entry, and notifies listeners. */
    @Override
    public void saveMetadata(CarrotMetadata metadata) {
        this.metadata.put(metadata.getName() + "-" + metadata.getOwner(), metadata);
        notifyMetadataChangeListeners();
    }

    /** @return all metadata entries with the given name, regardless of owner */
    @Override
    public List<CarrotMetadata> getMetadata(String name) {
        return this.metadata.values().stream().filter(m -> m.getName().equals(name)).collect(Collectors.toList());
    }

    /** @return the metadata entry for {@code name} and {@code owner}, or {@code null} if absent */
    @Override
    public CarrotMetadata getMetadata(String name, String owner) {
        return this.metadata.get(name + "-" + owner);
    }

    /** Removes every metadata entry with the given name and notifies listeners if anything was removed. */
    @Override
    public void deleteMetadata(String name) {
        // Remove by value instead of by a reconstructed key, so removal does not depend on the key format.
        final boolean removed = this.metadata.values().removeIf(m -> m.getName().equals(name));
        if (removed) {
            notifyMetadataChangeListeners();
        }
    }

    /**
     * Saves a deep copy of {@code task} and notifies task-stats listeners.
     *
     * @throws ConcurrentTaskModificationException if the stored version differs from {@code task}'s version
     */
    @Override
    public Task save(Task task) {
        saveTask(task);
        notifyTaskStatsOnChangeListeners();
        return task;
    }

    @Override
    public Task saveByPartition(Task task, Integer partition) {
        return save(task);
    }

    /** @return 1 if a task with {@code id} was removed, 0 otherwise */
    @Override
    public int deletePermanentlyByPartition(UUID id, Integer partition) {
        boolean removed = taskQueue.keySet().remove(id);
        notifyTaskStatsOnChangeListenersIf(removed);
        return removed ? 1 : 0;
    }

    /**
     * @return a deep copy of the stored task
     * @throws TaskNotFoundException if no task with {@code id} is stored
     */
    @Override
    public Task getTaskById(UUID id) {
        // Single lookup: a containsKey/get pair could observe a concurrent delete in between.
        final Task task = taskQueue.get(id);
        if (task == null) throw new TaskNotFoundException(id);
        return deepClone(task);
    }

    /**
     * Saves every task; tasks whose version conflicts with the stored one are collected.
     *
     * @throws ConcurrentTaskModificationException listing all conflicting tasks, if any
     */
    @Override
    public List<Task> save(List<Task> tasks) {
        final List<Task> concurrentModifiedTasks = returnConcurrentModifiedTasks(tasks, this::saveTask);
        if (!concurrentModifiedTasks.isEmpty()) {
            throw new ConcurrentTaskModificationException(concurrentModifiedTasks);
        }
        notifyTaskStatsOnChangeListeners();
        return tasks;
    }

    @Override
    public List<Task> saveByPartition(List<Task> tasks, Integer partition) {
        return save(tasks);
    }

    /** @return a page of deep copies of tasks in {@code state} updated strictly before {@code updatedBefore} */
    @Override
    public List<Task> getTasksByPartition(StateName state, Instant updatedBefore, PageRequest pageRequest, Integer partition) {
        return toPage(getTasksStream(state)
                .filter(task -> task.getUpdatedAt().isBefore(updatedBefore)), pageRequest);
    }

    /** @return a page of deep copies of scheduled tasks whose scheduled time is strictly before {@code scheduledBefore} */
    @Override
    public List<Task> getScheduledTasksByPartition(Instant scheduledBefore, PageRequest pageRequest, Integer partition) {
        return toPage(getTasksStream(StateName.SCHEDULED)
                .filter(task -> ((ScheduledState) task.getTaskState()).getScheduledAt().isBefore(scheduledBefore)), pageRequest);
    }

    /** @return a page of deep copies of tasks in {@code state} */
    @Override
    public List<Task> getTasksByPartition(StateName state, PageRequest pageRequest, Integer partition) {
        return toPage(getTasksStream(state), pageRequest);
    }

    /** @return a page of tasks in {@code state} together with the total number of tasks in that state */
    @Override
    public Page<Task> getTaskPageByPartition(StateName state, PageRequest pageRequest, Integer partition) {
        return new Page<>(getTasksStream(state).count(), getTasksByPartition(state, pageRequest, partition), pageRequest);
    }

    /**
     * Permanently removes tasks in {@code state} that were updated strictly before {@code updatedBefore}.
     *
     * @return number of tasks actually removed
     */
    @Override
    public int deleteTasksPermanentlyByPartition(StateName state, Instant updatedBefore, Integer partition) {
        int removedCount = 0;
        for (Task task : taskQueue.values()) {
            // Conditional remove: only drops the entry if it is still mapped to the value that matched,
            // so a task saved concurrently after matching is not deleted by mistake.
            if (task.hasState(state)
                    && task.getUpdatedAt().isBefore(updatedBefore)
                    && taskQueue.remove(task.getId(), task)) {
                removedCount++;
            }
        }
        notifyTaskStatsOnChangeListenersIf(removedCount > 0);
        return removedCount;
    }

    /** @return the distinct task signatures of all tasks whose state is one of {@code states} */
    @Override
    public Set<String> getDistinctTaskSignatures(StateName... states) {
        final List<StateName> wantedStates = asList(states);
        return taskQueue.values().stream()
                .filter(task -> wantedStates.contains(task.getState()))
                .map(AbstractTask::getTaskSignature)
                .collect(Collectors.toSet());
    }

    /** @return whether a task with the same signature as {@code taskDetails} exists in one of {@code states} */
    @Override
    public boolean existsByPartition(TaskDetails taskDetails, Integer partition, StateName... states) {
        final String actualTaskSignature = getTaskSignature(taskDetails);
        final List<StateName> wantedStates = asList(states);
        return taskQueue.values().stream()
                .anyMatch(task -> wantedStates.contains(task.getState())
                        && actualTaskSignature.equals(getTaskSignature(task.getTaskDetails())));
    }

    /** @return whether a task created by recurring task {@code recurringTaskId} exists in one of {@code states} */
    @Override
    public boolean recurringTaskExistsByPartition(String recurringTaskId, Integer partition, StateName... states) {
        final List<StateName> wantedStates = asList(states);
        return taskQueue.values().stream()
                .anyMatch(task ->
                        wantedStates.contains(task.getState())
                        && task.getRecurringTaskId()
                                .map(actualRecurringTaskId -> actualRecurringTaskId.equals(recurringTaskId))
                                .orElse(false));
    }

    /** Saves {@code recurringTask}, replacing any existing recurring task with the same id. */
    @Override
    public RecurringTask saveRecurringTask(RecurringTask recurringTask) {
        deleteRecurringTask(recurringTask.getId());
        recurringTasks.add(recurringTask);
        return recurringTask;
    }

    /** @return a snapshot of all recurring tasks (the single partition holds them all) */
    @Override
    public List<RecurringTask> getRecurringTasksByPartition(Integer partition) {
        return new ArrayList<>(recurringTasks);
    }

    /** @return a snapshot of all recurring tasks; mutating it does not affect the store */
    @Override
    public List<RecurringTask> getRecurringTasks() {
        return new ArrayList<>(recurringTasks);
    }

    @Override
    public long countRecurringTasksByPartition(Integer partition) {
        return recurringTasks.size();
    }

    /** @return 1 if a recurring task with {@code id} was removed, 0 otherwise */
    @Override
    public int deleteRecurringTask(String id) {
        return recurringTasks.removeIf(task -> id.equals(task.getId())) ? 1 : 0;
    }

    /** @return a single overall stats snapshot computed from the current in-memory contents */
    @Override
    public TaskStatsData getTaskStatsData() {
        // The stats entry is seeded in the constructor but can be removed via deleteMetadata(STATS_NAME).
        final CarrotMetadata succeededStats = getMetadata(STATS_NAME, STATS_OWNER);
        TaskStats taskStats = new TaskStats(
                Instant.now(),
                (long) taskQueue.size(),
                getTasksStream(StateName.SCHEDULED).count(),
                getTasksStream(StateName.ENQUEUED).count(),
                getTasksStream(StateName.PROCESSING).count(),
                getTasksStream(StateName.FAILED).count(),
                getTasksStream(StateName.SUCCEEDED).count(),
                succeededStats == null ? 0L : succeededStats.getValueAsLong(),
                getTasksStream(StateName.DELETED).count(),
                recurringTasks.size(),
                backgroundTaskServers.size()
        );
        TaskStatsData taskStatsData = new TaskStatsData();
        taskStatsData.setOverallTaskStats(taskStats);
        taskStatsData.getTaskStatsList().add(taskStats);
        return taskStatsData;
    }

    /** Adds {@code amount} to the all-time succeeded counter, creating it with value 0 if absent. */
    @Override
    public void publishTotalAmountOfSucceededTasks(int amount) {
        // compute() is atomic per key on the ConcurrentHashMap, so concurrent publishes do not
        // interleave their read-modify-write of the counter.
        this.metadata.compute(STATS_ID, (key, existingStats) -> {
            final CarrotMetadata stats = existingStats != null
                    ? existingStats
                    : new CarrotMetadata(STATS_NAME, STATS_OWNER, "0");
            stats.setValue(String.valueOf(parseLong(stats.getValue()) + amount));
            return stats;
        });
    }

    /** Sorts per {@code pageRequest}, applies offset/limit and deep-clones the resulting tasks. */
    private List<Task> toPage(Stream<Task> tasks, PageRequest pageRequest) {
        return tasks
                .sorted(getTaskComparator(pageRequest))
                .skip(pageRequest.getOffset())
                .limit(pageRequest.getLimit())
                .map(this::deepClone)
                .collect(Collectors.toList());
    }

    /** Streams the stored (non-cloned) tasks that are in {@code state}. */
    private Stream<Task> getTasksStream(StateName state) {
        return taskQueue.values().stream()
                .filter(task -> task.hasState(state));
    }

    /** Creates a field-by-field copy of {@code serverStatus} so the stored status is not shared with the caller. */
    private static BackgroundTaskServerStatus copyOf(BackgroundTaskServerStatus serverStatus) {
        return new BackgroundTaskServerStatus(
                serverStatus.getId(),
                serverStatus.getWorkerPoolSize(),
                serverStatus.getPollIntervalInSeconds(),
                serverStatus.getDeleteSucceededTasksAfter(),
                serverStatus.getPermanentlyDeleteDeletedTasksAfter(),
                serverStatus.getFirstHeartbeat(),
                serverStatus.getLastHeartbeat(),
                serverStatus.isRunning(),
                serverStatus.getSystemTotalMemory(),
                serverStatus.getSystemFreeMemory(),
                serverStatus.getSystemCpuLoad(),
                serverStatus.getProcessMaxMemory(),
                serverStatus.getProcessFreeMemory(),
                serverStatus.getProcessAllocatedMemory(),
                serverStatus.getProcessCpuLoad(),
                serverStatus.getPartition()
        );
    }

    /**
     * Stores a deep copy of {@code task} using optimistic versioning.
     *
     * @throws ConcurrentTaskModificationException if a stored task exists with a different version
     */
    private synchronized void saveTask(Task task) {
        final Task oldTask = taskQueue.get(task.getId());
        if (oldTask != null && task.getVersion() != oldTask.getVersion()) {
            throw new ConcurrentTaskModificationException(task);
        }

        // NOTE(review): TaskVersioner presumably bumps the version and reverts it on close unless
        // commitVersion() was reached — confirm against TaskVersioner.
        try (TaskVersioner taskVersioner = new TaskVersioner(task)) {
            taskQueue.put(task.getId(), deepClone(task));
            taskVersioner.commitVersion();
        }
    }

    /** Copies {@code task} via a serialize/deserialize round trip through the configured {@link TaskMapper}. */
    private Task deepClone(Task task) {
        final String serializedTaskAsString = taskMapper.serializeTask(task);
        final Task result = taskMapper.deserializeTask(serializedTaskAsString);
        // "locker" is copied reflectively, presumably because the mapper does not serialize it — TODO confirm.
        setFieldUsingAutoboxing("locker", result, getValueFromFieldOrProperty(task, "locker"));
        return result;
    }

    /**
     * Builds a comparator from {@code pageRequest.getOrder()}, a comma-separated list of
     * {@code field[:ASC|DESC]} entries where field is the created-at or updated-at field name (case-insensitive).
     *
     * @throws IllegalStateException if an unsupported sort field is requested
     */
    private Comparator<Task> getTaskComparator(PageRequest pageRequest) {
        final List<Comparator<Task>> comparators = new ArrayList<>();
        for (String sortOn : pageRequest.getOrder().split(",")) {
            final String[] sortAndOrder = sortOn.trim().split(":");
            final String sortField = sortAndOrder[0].trim();
            PageRequest.Order order = PageRequest.Order.ASC;
            if (sortAndOrder.length > 1) {
                // Locale.ROOT keeps the enum lookup independent of the default JVM locale.
                order = PageRequest.Order.valueOf(sortAndOrder[1].trim().toUpperCase(Locale.ROOT));
            }
            Comparator<Task> comparator;
            if (sortField.equalsIgnoreCase(FIELD_CREATED_AT)) {
                comparator = Comparator.comparing(Task::getCreatedAt);
            } else if (sortField.equalsIgnoreCase(FIELD_UPDATED_AT)) {
                comparator = Comparator.comparing(Task::getUpdatedAt);
            } else {
                throw new IllegalStateException("An unsupported sortOrder was requested: " + sortField);
            }
            if (order == PageRequest.Order.DESC) {
                comparator = comparator.reversed();
            }
            comparators.add(comparator);
        }
        return comparators.stream()
                .reduce(Comparator::thenComparing)
                .orElse((a, b) -> 0); // default order
    }
}
