/* __________              _____                                                *\
** \______   \____   _____/ ____\____   ____    Copyright (c) 2017-2023 Ponfee  **
**  |     ___/  _ \ /    \   __\/ __ \_/ __ \   http://www.ponfee.cn            **
**  |    |  (  <_> )   |  \  | \  ___/\  ___/   Apache License Version 2.0      **
**  |____|   \____/|___|  /__|  \___  >\___  >  http://www.apache.org/licenses/ **
**                      \/          \/     \/                                   **
\*                                                                              */

package cn.ponfee.commons.concurrent;

import com.google.common.base.Stopwatch;
import org.apache.commons.collections4.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Helpers for running tasks concurrently on an {@link Executor} and collecting their results.
 *
 * <p>Note: calling {@code Thread#stop()} on a worker surfaces inside the task as
 * "java.lang.ThreadDeath: null", which is swallowed by any {@code try...catch} that catches
 * {@code Throwable}.
 *
 * @author Ponfee
 */
public class MultithreadExecutors {

    private static final Logger LOG = LoggerFactory.getLogger(MultithreadExecutors.class);

    /**
     * Repeatedly runs {@code command} on {@code parallelism} concurrent workers for about
     * {@code execSeconds} seconds. It then signals the workers to stop and waits for each of
     * them to finish its current iteration. Usually used in test cases.
     *
     * <p>Do not pass an executor whose rejection policy runs tasks in the calling thread
     * (such as {@code ThreadPoolExecutor.CallerRunsPolicy}). The caller would then execute the
     * worker loop itself and never reach the code that stops it.
     *
     * @param parallelism the number of concurrent workers
     * @param command     the command each worker runs in a loop
     * @param execSeconds how long, in seconds, the workers keep looping
     * @param executor    the executor that runs the workers
     * @throws RuntimeException    if the calling thread is interrupted while sleeping or waiting
     * @throws CompletionException if {@code command} throws in any worker
     */
    public static void execute(int parallelism, Runnable command,
                               int execSeconds, Executor executor) {
        Stopwatch watch = Stopwatch.createStarted();
        AtomicBoolean flag = new AtomicBoolean(true);

        // CALLER_RUNS: the caller thread would run the loop itself and never reach flag.set(false).
        // NOTE(review): original note says "threadNumber > 32" -- presumably rejection starts once
        // parallelism exceeds the executor's capacity; confirm.
        try {
            CompletableFuture<?>[] futures = IntStream
                .range(0, parallelism)
                .mapToObj(i -> (Runnable) () -> {
                    while (flag.get() && !Thread.currentThread().isInterrupted()) {
                        command.run();
                    }
                })
                .map(runnable -> CompletableFuture.runAsync(runnable, executor))
                .toArray(CompletableFuture[]::new);

            // parent thread sleeps while the workers loop
            Thread.sleep(execSeconds * 1000L);
            flag.set(false);
            CompletableFuture.allOf(futures).join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } finally {
            // Always release the workers. If submission is rejected part-way, or sleep fails
            // (e.g. negative execSeconds), the workers already started would otherwise loop forever.
            flag.set(false);
            LOG.info("multi thread exec async duration: {}", watch.stop());
        }
    }

    // -----------------------------------------------------------------runAsync

    /**
     * Runs {@code command} once in each of {@code parallelism} concurrent tasks and blocks
     * until all of them have completed.
     *
     * @param command     the command
     * @param parallelism the number of concurrent runs
     * @param executor    the executor that runs the command
     * @throws CompletionException if {@code command} throws in any of the runs
     */
    public static void execute(Runnable command, int parallelism, Executor executor) {
        Stopwatch watch = Stopwatch.createStarted();
        CompletableFuture<?>[] futures = IntStream.range(0, parallelism)
                                                  .mapToObj(i -> CompletableFuture.runAsync(command, executor))
                                                  .toArray(CompletableFuture[]::new);

        CompletableFuture.allOf(futures).join();
        LOG.info("multi thread run async duration: {}", watch.stop());
    }

    // -----------------------------------------------------------------callAsync

    /**
     * Calls {@code supplier} in {@code parallelism} concurrent tasks and returns the results in
     * submission order. The tasks run on the default executor of
     * {@link CompletableFuture#supplyAsync(Supplier)}.
     *
     * @param supplier    the supplier
     * @param parallelism the number of concurrent calls
     * @param <U>         result type
     * @return the results, one per call
     * @throws CompletionException if {@code supplier} throws in any of the calls
     */
    public static <U> List<U> execute(Supplier<U> supplier, int parallelism) {
        Stopwatch watch = Stopwatch.createStarted();
        // Collect the futures first so that every task is submitted before the first join.
        // A single lazy stream would submit and join the tasks one at a time.
        List<U> result = IntStream.range(0, parallelism)
                                  .mapToObj(i -> CompletableFuture.supplyAsync(supplier))
                                  .collect(Collectors.toList())
                                  .stream()
                                  .map(CompletableFuture::join)
                                  .collect(Collectors.toList());
        LOG.info("multi thread call async duration: {}", watch.stop());
        return result;
    }

    // -----------------------------------------------------------------runAsync

    /**
     * Applies {@code action} to every element of {@code coll} concurrently and blocks until
     * all elements have been processed.
     *
     * @param coll     the T collection
     * @param action   the T action
     * @param executor thread executor service
     * @param <T>      element type
     * @throws CompletionException if {@code action} throws for any element
     */
    public static <T> void execute(Collection<T> coll, Consumer<T> action, Executor executor) {
        Stopwatch watch = Stopwatch.createStarted();
        // Submit everything (terminal collect) before joining, so the actions run concurrently.
        coll.stream()
            .map(e -> CompletableFuture.runAsync(() -> action.accept(e), executor))
            .collect(Collectors.toList())
            .forEach(CompletableFuture::join);
        LOG.info("multi thread run async duration: {}", watch.stop());
    }

    // -----------------------------------------------------------------callAsync

    /**
     * Maps every element of {@code coll} from T to U concurrently.
     *
     * @param coll     the T collection
     * @param mapper   the mapper of T to U
     * @param executor thread executor service
     * @param <T>      source element type
     * @param <U>      mapped element type
     * @return the mapped values, in the iteration order of {@code coll}
     * @throws CompletionException if {@code mapper} throws for any element
     */
    public static <T, U> List<U> execute(Collection<T> coll, Function<T, U> mapper, Executor executor) {
        Stopwatch watch = Stopwatch.createStarted();
        // Submit everything before the first join so the mappings run concurrently.
        List<U> result = coll.stream()
                             .map(e -> CompletableFuture.supplyAsync(() -> mapper.apply(e), executor))
                             .collect(Collectors.toList())
                             .stream()
                             .map(CompletableFuture::join)
                             .collect(Collectors.toList());
        LOG.info("multi thread call async duration: {}", watch.stop());
        return result;
    }

    /**
     * Uses the number of data items (tasks) to decide whether to run {@code action} in the
     * calling thread or on {@code executor}.
     *
     * <p>The calling thread is used when {@code dataSizeThreshold < 1} or
     * {@code data.size() < dataSizeThreshold}. Otherwise every element is submitted to
     * {@code executor}. In both cases the results follow the iteration order of {@code data}.
     *
     * @param data              the data
     * @param action            the action
     * @param dataSizeThreshold the minimum data size at which the executor is used
     * @param executor          the executor
     * @param <T>               data element type
     * @param <R>               result element type
     * @return the action results, in the iteration order of {@code data};
     *         an empty list if {@code data} is null or empty
     * @throws RuntimeException if the calling thread is interrupted or an action fails
     *                          on the executor path
     */
    public static <T, R> List<R> execute(Collection<T> data, Function<T, R> action,
                                         int dataSizeThreshold, Executor executor) {
        if (CollectionUtils.isEmpty(data)) {
            return Collections.emptyList();
        }
        if (dataSizeThreshold < 1 || data.size() < dataSizeThreshold) {
            return data.stream().map(action).collect(Collectors.toList());
        }

        // Submit all tasks first, then read them back in input order. A CompletionService would
        // return results in completion order, unlike the sequential branch above.
        List<Future<R>> futures = new ArrayList<>(data.size());
        for (T e : data) {
            FutureTask<R> task = new FutureTask<>(() -> action.apply(e));
            executor.execute(task);
            futures.add(task);
        }

        List<R> result = new ArrayList<>(futures.size());
        for (Future<R> future : futures) {
            result.add(getResult(future));
        }
        return result;
    }

    /**
     * Uses the number of data items (tasks) to decide whether to run {@code action} in the
     * calling thread or on {@code executor}.
     *
     * <p>The calling thread is used when {@code dataSizeThreshold < 1} or
     * {@code data.size() < dataSizeThreshold}. Otherwise every element is submitted to
     * {@code executor} and this method blocks until all of them have completed.
     *
     * @param data              the data
     * @param action            the action
     * @param dataSizeThreshold the minimum data size at which the executor is used
     * @param executor          the executor
     * @param <T>               data element type
     * @throws RuntimeException if the calling thread is interrupted or an action fails
     *                          on the executor path
     */
    public static <T> void execute(Collection<T> data, Consumer<T> action,
                                   int dataSizeThreshold, Executor executor) {
        if (CollectionUtils.isEmpty(data)) {
            return;
        }
        if (dataSizeThreshold < 1 || data.size() < dataSizeThreshold) {
            data.forEach(action);
            return;
        }

        CompletionService<Void> service = new ExecutorCompletionService<>(executor);
        data.forEach(e -> service.submit(() -> action.accept(e), null));
        joinDiscard(service, data.size());
    }

    // -----------------------------------------------------------------join

    /**
     * Takes {@code count} completed tasks from {@code service} and collects their results.
     *
     * @param service the completion service the tasks were submitted to
     * @param count   the number of tasks to wait for
     * @param <T>     result type
     * @return the results, in completion order
     * @throws RuntimeException if interrupted while waiting, or wrapping the
     *                          {@link ExecutionException} of a failed task
     */
    public static <T> List<T> join(CompletionService<T> service, int count) {
        List<T> result = new ArrayList<>(Math.max(count, 0));
        join(service, count, result::add);
        return result;
    }

    /**
     * Waits for {@code count} tasks of {@code service} to complete and discards their results.
     *
     * @param service the completion service the tasks were submitted to
     * @param count   the number of tasks to wait for
     * @param <T>     result type
     * @throws RuntimeException if interrupted while waiting, or wrapping the
     *                          {@link ExecutionException} of a failed task
     */
    public static <T> void joinDiscard(CompletionService<T> service, int count) {
        join(service, count, t -> { });
    }

    /**
     * Takes {@code count} completed tasks from {@code service} and passes each result to
     * {@code accept} in completion order. It stops at the first failed task.
     *
     * @param service the completion service the tasks were submitted to
     * @param count   the number of tasks to wait for
     * @param accept  receives each task result
     * @param <T>     result type
     * @throws RuntimeException if interrupted while waiting (the interrupt flag is restored),
     *                          or wrapping the {@link ExecutionException} of a failed task
     */
    public static <T> void join(CompletionService<T> service, int count, Consumer<T> accept) {
        try {
            while (count-- > 0) {
                // block until a task done
                Future<T> future = service.take();
                accept.accept(future.get());
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Blocks for the result of {@code future}, using the same exception wrapping as
     * {@link #join(CompletionService, int, Consumer)}.
     */
    private static <T> T getResult(Future<T> future) {
        try {
            return future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

}
