package org.hansken.plugin.extraction.runtime.grpc.server;

import static org.hansken.plugin.extraction.runtime.grpc.common.Checks.isMetaContext;
import static org.hansken.plugin.extraction.util.ArgChecks.argNotNegative;
import static org.hansken.plugin.extraction.util.ArgChecks.argNotNull;

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import org.hansken.extraction.plugin.grpc.ExtractionPluginServiceGrpc.ExtractionPluginServiceImplBase;
import org.hansken.extraction.plugin.grpc.RpcDataContext;
import org.hansken.extraction.plugin.grpc.RpcPluginInfo;
import org.hansken.extraction.plugin.grpc.RpcStart;
import org.hansken.extraction.plugin.grpc.RpcTrace;
import org.hansken.plugin.extraction.api.BaseExtractionPlugin;
import org.hansken.plugin.extraction.api.DataContext;
import org.hansken.plugin.extraction.api.DeferredExtractionPlugin;
import org.hansken.plugin.extraction.api.DeferredMetaExtractionPlugin;
import org.hansken.plugin.extraction.api.ExtractionPlugin;
import org.hansken.plugin.extraction.api.MetaExtractionPlugin;
import org.hansken.plugin.extraction.runtime.grpc.common.Pack;
import org.hansken.plugin.extraction.runtime.grpc.common.Unpack;
import org.hansken.plugin.extraction.runtime.grpc.server.proxy.ExtractionContextProxy;
import org.hansken.plugin.extraction.runtime.grpc.server.proxy.GrpcFacade;
import org.hansken.plugin.extraction.runtime.grpc.server.proxy.TraceProxy;
import org.hansken.plugin.extraction.runtime.grpc.server.proxy.TraceSearcherProxy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.protobuf.Any;
import com.google.protobuf.Empty;

import io.grpc.Status;
import io.grpc.stub.StreamObserver;

/**
 * Implements the actual gRPC service definition.
 * All RPC calls from the gRPC client to the gRPC server are handled here.
 * <p>
 * Every {@link #process(StreamObserver)} call gets its own {@link GrpcFacade} and its own incoming
 * message observer. The plugin itself is executed on a thread of a shared, fixed size worker pool,
 * while incoming messages are handled on the thread that invokes the returned observer.
 *
 * @author Netherlands Forensic Institute
 */
public class ExtractionPluginServerService extends ExtractionPluginServiceImplBase {
    private static final Logger LOG = LoggerFactory.getLogger(ExtractionPluginServerService.class);

    /** Number of threads in the pool on which plugin invocations are executed. */
    private static final int WORKER_THREAD_COUNT = 8;

    /** Divisor used to convert a duration in nanoseconds to (fractional) seconds. */
    private static final double NANOS_PER_SECOND = 1_000_000_000;

    private final ExecutorService _workers = Executors.newFixedThreadPool(WORKER_THREAD_COUNT);
    private final Supplier<BaseExtractionPlugin> _plugin;
    private final int _maximumMessageSize;

    /**
     * Create a new service which serves the plugin provided by given supplier.
     *
     * @param plugin supplier of the plugin to serve, checked with {@code argNotNull}
     * @param maximumMessageSize the maximum message size which is handed to each {@link GrpcFacade},
     *                           checked with {@code argNotNegative}
     */
    protected ExtractionPluginServerService(final Supplier<BaseExtractionPlugin> plugin, final int maximumMessageSize) {
        _plugin = argNotNull("plugin", plugin);
        _maximumMessageSize = argNotNegative("maximumMessageSize", maximumMessageSize);
    }

    /**
     * Send the packed plugin info of the served plugin to the client and complete the call.
     *
     * @param request the (empty) request
     * @param responseObserver observer which receives the plugin info
     */
    @Override
    public void pluginInfo(final Empty request, final StreamObserver<RpcPluginInfo> responseObserver) {
        responseObserver.onNext(Pack.pluginInfo(_plugin.get().pluginInfo()));
        responseObserver.onCompleted();
    }

    /**
     * Start a new bidirectional processing stream.
     * <p>
     * The first {@link RpcStart} message received on the returned observer schedules the plugin
     * on the worker pool. Every other message is handed to the facade as a response to a request
     * made by the running plugin.
     *
     * @param outgoingMessages observer used for sending messages to the client
     * @return the observer which handles the messages received from the client
     */
    @Override
    @SuppressWarnings("checkstyle:anoninnerlength")
    public StreamObserver<Any> process(final StreamObserver<Any> outgoingMessages) {
        // queue where request responses are stored until the tool consumes the response
        // since we have at most one blocking call per process(), we have to store at most one reply
        final BlockingQueue<Any> incomingMessages = new ArrayBlockingQueue<>(1);
        final GrpcFacade facade = new GrpcFacade(incomingMessages, outgoingMessages, _maximumMessageSize);

        return new StreamObserver<>() {

            // set once the first start message of this stream has been received
            private final AtomicBoolean _started = new AtomicBoolean();

            // the pool thread currently processing this stream, or null when none is
            private final AtomicReference<Thread> _myThreadReference = new AtomicReference<>();

            @Override
            @SuppressWarnings("checkstyle:illegalcatch")
            public void onNext(final Any any) {
                try {
                    if (any.is(RpcStart.class)) {
                        if (_started.getAndSet(true)) {
                            // if we already received a start message in the same stream, something is wrong
                            facade.onError(Status.Code.FAILED_PRECONDITION, new IllegalStateException("processing of the trace has already been started"));
                            return;
                        }
                        // unpack on the calling thread: a failure is then reported to the client by the
                        // catch below, instead of getting lost as an uncaught exception on a pool thread
                        final RpcStart start = Unpack.start(any);
                        _workers.execute(() -> {
                            _myThreadReference.set(Thread.currentThread());
                            try {
                                process(start.getTrace(), start.getDataContext(), facade);
                            }
                            finally {
                                // the thread goes back to the pool and may start serving another stream,
                                // so a late onError() of this stream must no longer interrupt it
                                _myThreadReference.set(null);
                            }
                        });
                        return;
                    }
                    facade.handleResponse(any);
                }
                catch (final Throwable t) {
                    facade.onError(Status.Code.CANCELLED, t);
                }
            }

            /**
             * Run the plugin on given trace, and report either the successful finish or the
             * failure, together with the elapsed time in seconds, through the facade.
             */
            @SuppressWarnings("checkstyle:illegalcatch")
            private void process(final RpcTrace rpcTrace, final RpcDataContext rpcDataContext, final GrpcFacade facade) {
                final long start = System.nanoTime();
                String id = null;
                try {
                    final BaseExtractionPlugin plugin = _plugin.get();
                    // NOTE(review): these proxies are only used for logging, process(plugin, ...) creates
                    // its own pair from the same rpc messages -- confirm whether they could be shared
                    try (TraceProxy trace = TraceProxy.fromRpc(rpcTrace, facade);
                         ExtractionContextProxy context = ExtractionContextProxy.fromRpc(rpcDataContext, trace.traceId(), facade)) {
                        id = trace.get("id");
                        logStartProcess(context, id);
                        process(plugin, rpcTrace, rpcDataContext, facade);
                    }

                    // call finish after processing is done, so everything is flushed
                    final long duration = System.nanoTime() - start;
                    facade.finishProcessing(duration / NANOS_PER_SECOND);

                    LOG.info("Finished processed trace with id: {}", id);
                }
                catch (final Throwable t) {
                    LOG.error("Error during processing trace with id: {}", id, t);
                    final long duration = System.nanoTime() - start;
                    facade.processPartialResultOrError(t, duration / NANOS_PER_SECOND);
                }
            }

            /**
             * Log the start of the processing of a trace, including the size of the data
             * when the context is not a meta context.
             */
            private void logStartProcess(final DataContext context, final String id) {
                final String dataType = context.dataType();

                // meta context has no associated data stream
                if (isMetaContext(context)) {
                    LOG.info("Started processing trace with id: {}, data type: {}", id, dataType);
                }
                else {
                    final long size = context.data().size();
                    LOG.info("Started processing trace with id: {}, data type: {}, size: {}", id, dataType, size);
                }
            }

            /**
             * Invoke the {@code process} method matching the concrete type of given plugin.
             *
             * @throws IllegalArgumentException when the plugin is of none of the supported types
             */
            private void process(final BaseExtractionPlugin plugin, final RpcTrace rpcTrace, final RpcDataContext rpcDataContext, final GrpcFacade facade) throws Exception {
                argNotNull("plugin", plugin);
                try (
                    TraceProxy trace = TraceProxy.fromRpc(rpcTrace, facade);
                    ExtractionContextProxy context = ExtractionContextProxy.fromRpc(rpcDataContext, trace.traceId(), facade)) {
                    switch (plugin) {
                        case final MetaExtractionPlugin metaExtractionPlugin ->
                            metaExtractionPlugin.process(trace);
                        case final ExtractionPlugin extractionPlugin ->
                            extractionPlugin.process(trace, context);
                        case final DeferredMetaExtractionPlugin deferredMetaExtractionPlugin ->
                            deferredMetaExtractionPlugin.process(trace, new TraceSearcherProxy(facade));
                        case final DeferredExtractionPlugin deferredExtractionPlugin ->
                            deferredExtractionPlugin.process(trace, context, new TraceSearcherProxy(facade));
                        default ->
                            throw new IllegalArgumentException("Provided plugin is not a known implementation of ExtractionPlugin or DeferredExtractionPlugin");
                    }
                }
            }

            @Override
            public void onError(final Throwable t) {
                LOG.error("Error received from stream: {}", t.getMessage(), t);
                // HANSKEN-13815: Improve and test process()-method stop on error
                // read the reference only once: the worker resets it to null as soon as it is done
                final Thread worker = _myThreadReference.get();
                if (worker != null) {
                    worker.interrupt();
                }
            }

            @Override
            public void onCompleted() {
                facade.onCompleted();
            }
        };
    }
}
