package org.hansken.plugin.extraction.runtime.grpc.client;

import static java.util.Locale.ROOT;
import static java.util.concurrent.TimeUnit.SECONDS;

import static org.hansken.plugin.extraction.util.ArgChecks.argNotNull;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Predicate;

import org.hansken.extraction.plugin.grpc.RpcBatchUpdate;
import org.hansken.extraction.plugin.grpc.RpcBeginChild;
import org.hansken.extraction.plugin.grpc.RpcBeginDataStream;
import org.hansken.extraction.plugin.grpc.RpcEnrichTrace;
import org.hansken.extraction.plugin.grpc.RpcFinish;
import org.hansken.extraction.plugin.grpc.RpcFinishChild;
import org.hansken.extraction.plugin.grpc.RpcFinishDataStream;
import org.hansken.extraction.plugin.grpc.RpcNull;
import org.hansken.extraction.plugin.grpc.RpcPartialFinishWithError;
import org.hansken.extraction.plugin.grpc.RpcRead;
import org.hansken.extraction.plugin.grpc.RpcSearchRequest;
import org.hansken.extraction.plugin.grpc.RpcTrace;
import org.hansken.extraction.plugin.grpc.RpcWriteDataStream;
import org.hansken.plugin.extraction.api.SearchResult;
import org.hansken.plugin.extraction.api.Trace;
import org.hansken.plugin.extraction.api.TraceSearcher;
import org.hansken.plugin.extraction.runtime.grpc.client.api.ClientDataContext;
import org.hansken.plugin.extraction.runtime.grpc.client.api.ClientTrace;
import org.hansken.plugin.extraction.runtime.grpc.common.Pack;
import org.hansken.plugin.extraction.runtime.grpc.common.Unpack;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.protobuf.Any;

/**
 * Adapter that transforms incoming gRPC messages to clean API calls.
 * <p>
 * The adapter keeps track of the traces that are in progress: the trace that is currently being
 * processed (registered under id {@code "0"}) and every child trace that has been started but not
 * yet finished. Whenever handling a message fails, the data stream transfers of all in-progress
 * traces are finished and the internal executor is shut down before the failure is rethrown.
 * <p>
 * NOTE(review): no synchronization is used here; presumably an instance is driven by a single
 * message stream at a time — confirm with the caller.
 *
 * @author Netherlands Forensic Institute
 */
@SuppressWarnings("checkstyle:illegalcatch")
public class ExtractionPluginGrpcAdapter {
    /** Id under which the trace that is currently being processed (the root trace) is registered. */
    static final String CURRENT_PROCESS_TRACE_MARKER = "0";

    private static final Logger LOG = LoggerFactory.getLogger(ExtractionPluginGrpcAdapter.class);

    /** Maximum time to wait for running transfer tasks after the executor has been shut down. */
    private static final long EXECUTOR_TERMINATION_TIMEOUT_SECONDS = 5;

    /**
     * This map tracks the in progress traces, in the order in which they were started
     * (the root trace first). This allows us to:
     * <ul>
     * <li>perform multiple actions (streaming writes for example) on a trace before flushing
     * <li>and to ensure that parent traces are flushed before their child-traces
     * </ul>
     */
    private final Map<String, TraceState> _inProgressTraces = new LinkedHashMap<>();
    // handed to the data stream transfer manager of every trace; shut down when handling a message fails
    private final ExecutorService _executor = Executors.newCachedThreadPool();
    private final ExtractionPluginDataReader _reader;

    // variable to keep error information in case the plugin failed due to an error
    private Throwable _error;
    // may be null, in which case search requests are rejected
    private final TraceSearcher _searcher;

    private final TraceState _rootTrace;

    /**
     * Creates an adapter without a {@link TraceSearcher}; search requests will be rejected.
     *
     * @param trace the trace that is being processed, becomes the root of all child traces
     * @param context client data context, must not be {@code null}
     * @param reader reader used to answer data read requests, must not be {@code null}
     */
    public ExtractionPluginGrpcAdapter(final ClientTrace trace, final ClientDataContext context, final ExtractionPluginDataReader reader) {
        this(trace, context, reader, null);
    }

    /**
     * Creates an adapter.
     *
     * @param trace the trace that is being processed, becomes the root of all child traces
     * @param context client data context, must not be {@code null}
     * @param reader reader used to answer data read requests, must not be {@code null}
     * @param searcher searcher used to answer search requests, may be {@code null}
     */
    public ExtractionPluginGrpcAdapter(final ClientTrace trace, final ClientDataContext context, final ExtractionPluginDataReader reader, final TraceSearcher searcher) {
        argNotNull("trace", trace);
        argNotNull("context", context);
        _rootTrace = rootState(trace);
        stackPush(_rootTrace.id(), _rootTrace);
        _searcher = searcher;
        _reader = argNotNull("reader", reader);
    }

    /**
     * Processes a gRPC message, and calls the corresponding clean API.
     * <p>
     * If processing fails, all pending data stream transfers are finished and the internal executor
     * is shut down before the original failure is rethrown. A failure of that cleanup itself is
     * attached to the original failure as a suppressed exception.
     *
     * @param message Message received from client over gRPC
     * @return A response as result of the message execution, empty if there is no
     *         response.
     * @throws ExecutionException when an exception occurs during computation
     * @throws InterruptedException when a thread gets interrupted
     * @throws IOException when an I/O exception occurs
     */
    public Optional<Any> execute(final Any message) throws ExecutionException, InterruptedException, IOException {
        try {
            return executeUnsafe(message);
        }
        catch (final Throwable t) {
            try {
                shutdownTransferStates();
            }
            catch (final RuntimeException cleanupError) {
                // never let a failing cleanup hide the error that triggered it
                t.addSuppressed(cleanupError);
            }
            throw t;
        }
    }

    /**
     * Dispatches a single message without any cleanup on failure. Nested messages (batch updates,
     * finish and partial finish actions) are dispatched through this method as well, so that the
     * cleanup in {@link #execute(Any)} runs exactly once, at the outermost level.
     */
    private Optional<Any> executeUnsafe(final Any message) throws ExecutionException, InterruptedException, IOException {
        if (message.is(RpcBeginDataStream.class)) {
            final RpcBeginDataStream rpcBeginData = Unpack.any(message, RpcBeginDataStream.class);
            inProgressTrace(rpcBeginData.getTraceId(), RpcBeginDataStream.class).transfers().start(rpcBeginData.getDataType());
            return Optional.empty();
        }
        if (message.is(RpcWriteDataStream.class)) {
            final RpcWriteDataStream rpcWriteData = Unpack.any(message, RpcWriteDataStream.class);
            inProgressTrace(rpcWriteData.getTraceId(), RpcWriteDataStream.class).transfers()
                .get(rpcWriteData.getDataType())
                .write(rpcWriteData.getData().toByteArray());
            return Optional.empty();
        }
        if (message.is(RpcFinishDataStream.class)) {
            final RpcFinishDataStream rpcFinishData = Unpack.any(message, RpcFinishDataStream.class);
            inProgressTrace(rpcFinishData.getTraceId(), RpcFinishDataStream.class).transfers().finish(rpcFinishData.getDataType());
            return Optional.empty();
        }
        if (message.is(RpcRead.class)) {
            final RpcRead read = Unpack.any(message, RpcRead.class);
            return Optional.of(Any.pack(Pack.primitive(executeRead(read))));
        }
        if (message.is(RpcSearchRequest.class)) {
            if (_searcher == null) {
                throw new IllegalStateException("unable to handle search request: no trace searcher was configured");
            }
            final RpcSearchRequest searchRequest = message.unpack(RpcSearchRequest.class);
            final SearchResult result = _searcher.search(searchRequest.getQuery(), searchRequest.getCount());
            // traces returned by a search may later be read from, so register them with the reader
            result.getTraces().forEach(_reader::markDataAvailable);
            return Optional.of(Any.pack(Pack.searchResult(result)));
        }
        // TODO HANSKEN-13962: write to trace within scope of a child
        if (message.is(RpcEnrichTrace.class)) {
            final RpcTrace rpcTrace = Unpack.any(message, RpcEnrichTrace.class).getTrace();
            enrichTraceWith(inProgressTrace(rpcTrace.getId(), RpcEnrichTrace.class).trace(), rpcTrace);
            return Optional.empty();
        }
        if (message.is(RpcBeginChild.class)) {
            final RpcBeginChild rpcBeginTrace = message.unpack(RpcBeginChild.class);
            final String childId = rpcBeginTrace.getId();
            final TraceState parent = stackPeek(getParentId(childId));
            // validate before creating the child, so no child trace is created for an invalid id
            assertIsDirectParentIdOf(parent.id(), childId);
            final ClientTrace child = parent.newChild(rpcBeginTrace.getName());
            stackPush(childId, childState(childId, child));
            return Optional.empty();
        }
        if (message.is(RpcFinishChild.class)) {
            finishChild(Unpack.any(message, RpcFinishChild.class).getId());
            return Optional.empty();
        }
        if (message.is(RpcBatchUpdate.class)) {
            for (final Any any : Unpack.any(message, RpcBatchUpdate.class).getActionsList()) {
                executeUnsafe(any);
            }
            return Optional.of(Any.pack(RpcNull.getDefaultInstance()));
        }
        if (message.is(RpcFinish.class)) {
            for (final Any any : Unpack.any(message, RpcFinish.class).getUpdate().getActionsList()) {
                executeUnsafe(any);
            }
            if (_inProgressTraces.size() != 1) {
                throw new IllegalStateException("not all children have been finished");
            }
            return Optional.empty();
        }
        if (message.is(RpcPartialFinishWithError.class)) {
            for (final Any any : Unpack.any(message, RpcPartialFinishWithError.class).getActionsList()) {
                executePartial(any);
            }
            if (_inProgressTraces.size() != 1) {
                throw new IllegalStateException("not all children have been finished");
            }
            _rootTrace.transfers().finishAll();
            return Optional.empty();
        }
        throw new IllegalStateException("unsupported type of message: " + message);
    }

    /**
     * Removes the child trace with the given id from the in-progress traces and saves it.
     *
     * @param id id of the child trace to finish
     * @throws IllegalStateException if the id refers to the root trace, if no such child is in
     *         progress, or if saving the child fails (the failure is attached as cause)
     */
    private void finishChild(final String id) {
        if (CURRENT_PROCESS_TRACE_MARKER.equals(id)) {
            throw new IllegalStateException("the trace that is currently being processed cannot be finished as a child");
        }
        final TraceState trace = stackPop(id);
        if (trace == null) {
            throw new IllegalStateException(String.format(ROOT, "cannot finish child with id %s: it is not in progress", id));
        }
        assertEqualId(id, trace.id(), RpcFinishChild.class);
        final ClientTrace child = trace.trace();
        try {
            child.save();
        }
        catch (final Exception e) {
            throw new IllegalStateException("Unable to store child: " + child.name(), e);
        }
    }

    /**
     * Executes an action of a partially finished (failed) plugin run. Before the action itself is
     * executed, every in-progress trace that was started after the trace the action refers to is
     * removed (without being saved) and has its data stream transfers finished.
     */
    private Optional<Any> executePartial(final Any message) throws ExecutionException, InterruptedException, IOException {
        // TODO HANSKEN-14694: children don't seem to be flushed when error happens on client side?
        if (message.is(RpcEnrichTrace.class)) {
            final String id = Unpack.any(message, RpcEnrichTrace.class).getTrace().getId();
            abandonTracesUntil(id::equals);
        }
        else if (message.is(RpcFinishChild.class)) {
            final String id = Unpack.any(message, RpcFinishChild.class).getId();
            abandonTracesUntil(id::equals);
        }
        else if (message.is(RpcBeginChild.class)) {
            final String id = Unpack.any(message, RpcBeginChild.class).getId();
            abandonTracesUntil(traceId -> isDirectParentIdOf(traceId, id));
        }
        return executeUnsafe(message);
    }

    /**
     * Walks the in-progress traces from the most recently started one backwards, removing each trace
     * (without saving it) and finishing its data stream transfers, until a trace whose id matches
     * {@code isTarget} is found. The matching trace itself is left untouched. If none matches, all
     * in-progress traces are removed.
     */
    private void abandonTracesUntil(final Predicate<String> isTarget) throws ExecutionException, InterruptedException, IOException {
        final Deque<String> traceStack = new ArrayDeque<>(_inProgressTraces.keySet());
        while (!traceStack.isEmpty()) {
            final String traceId = traceStack.removeLast();
            if (isTarget.test(traceId)) {
                return;
            }
            stackPop(traceId).transfers().finishAll();
        }
    }

    /**
     * Looks up the in-progress trace with the given id and verifies that it actually exists.
     *
     * @throws IllegalStateException if no trace with the given id is in progress
     */
    private TraceState inProgressTrace(final String traceId, final Class<?> messageClass) {
        final TraceState state = stackPeek(traceId);
        assertEqualId(traceId, state.id(), messageClass);
        return state;
    }

    private static void assertEqualId(final String actual, final String expected, final Class<?> messageClass) {
        if (!actual.equals(expected)) {
            throw new IllegalStateException(String.format(ROOT, "id is expected to be %s, but instead is %s. source:%s", expected, actual, messageClass.getSimpleName()));
        }
    }

    private static void assertIsDirectParentIdOf(final String parentId, final String childId) {
        if (!isDirectParentIdOf(parentId, childId)) {
            throw new IllegalStateException(String.format(ROOT, "id %s is expected to be a direct child of %s, but it is not", childId, parentId));
        }
    }

    /**
     * Returns whether {@code childId} is {@code parentId} followed by exactly one more
     * {@code -<n>} level, e.g. {@code 0-1-2} is a direct child of {@code 0-1}.
     */
    private static boolean isDirectParentIdOf(final String parentId, final String childId) {
        return childId.startsWith(parentId) && parentId.length() == childId.lastIndexOf('-');
    }

    private static void enrichTraceWith(final Trace child, final RpcTrace rpcChild) {
        // Handling the transformations first, or at least before the properties,
        // ensures that the plugin that makes the data stream can also set properties about that data stream.
        // The types should also be set before the properties are set.
        rpcChild.getTransformationsList().forEach(rpcTransformation ->
            child.setData(rpcTransformation.getDataType(), Unpack.transformations(rpcTransformation.getTransformationsList()))
        );
        rpcChild.getTypesList().forEach(child::addType);
        rpcChild.getPropertiesList().forEach(property ->
            child.set(property.getName(), Unpack.primitive(property.getValue()))
        );
        rpcChild.getTraceletsList().forEach(rpcTracelet ->
            child.addTracelet(Unpack.tracelet(rpcTracelet))
        );
    }

    private byte[] executeRead(final RpcRead read) throws IOException {
        final String traceUid = unpackTraceUid(read);
        return _reader.read(traceUid, read.getDataType(), read.getPosition(), read.getCount());
    }

    /**
     * Returns the traceUid field from the RpcRead message. If a plugin uses an sdk version {@code < 0.4.9}, this value is obtained
     * from the traceId field. The traceId field is not unique for traces from different images, so using this field is only allowed
     * for the extraction trace. Using this field for search traces may cause wrong data to return. See HANSKEN-15913.
     *
     * @param read RpcRead message
     * @return traceUid
     * @throws IllegalStateException if an older version of the sdk is used AND data is retrieved for a search trace
     */
    private String unpackTraceUid(final RpcRead read) {
        final String legacyTraceId = read.getTraceId();
        // If this field is not empty, an older sdk version (< 0.4.9) was used to create this plugin.
        if (legacyTraceId.isEmpty()) {
            return read.getTraceUid();
        }
        // Only the extraction trace may be addressed through the legacy field.
        if (!legacyTraceId.equals(CURRENT_PROCESS_TRACE_MARKER)) {
            throw new IllegalStateException("The Extraction Plugin SDK used to create your plugin is outdated. This version still uses the traceId field instead of traceUid to identify trace data. Update to at least 0.4.9.");
        }
        return legacyTraceId;
    }

    /**
     * Finishes the data stream transfers of all in-progress traces (failures are logged, not thrown),
     * then shuts down the executor and waits a bounded time for running tasks to stop.
     *
     * @throws IllegalStateException if interrupted while waiting; the interrupt flag is restored
     */
    private void shutdownTransferStates() {
        for (final TraceState state : _inProgressTraces.values()) {
            try {
                state.transfers().finishAll();
            }
            catch (final Exception e) {
                LOG.error("error while trying to shut down transfer for Trace with id {}", state.id(), e);
            }
        }
        _executor.shutdownNow();
        try {
            if (!_executor.awaitTermination(EXECUTOR_TERMINATION_TIMEOUT_SECONDS, SECONDS)) {
                LOG.warn("transfer executor did not terminate within {} seconds", EXECUTOR_TERMINATION_TIMEOUT_SECONDS);
            }
        }
        catch (final InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
    }

    /**
     * Callback for errors. One can check whether errors have been set by using the error-method.
     * The error will be logged with a default logMessage
     *
     * @param error an error thrown somewhere else
     */
    public void error(final Throwable error) {
        error(error, "Error occurred during processing of the trace");
    }

    /**
     * Callback for errors. One can check whether errors have been set by using the error-method.
     *
     * @param error an error thrown somewhere else
     * @param logMessage the logged message
     */
    public void error(final Throwable error, final String logMessage) {
        // Note: Hansken will log this error as a warning, like any other tool that failed.
        LOG.debug(logMessage, error);
        _error = error;
    }

    /**
     * Returns the error set, if any.
     *
     * @return empty if no error was set, otherwise optional with the error set.
     */
    public Optional<Throwable> error() {
        return Optional.ofNullable(_error);
    }

    private TraceState rootState(final ClientTrace root) {
        return new TraceState(CURRENT_PROCESS_TRACE_MARKER, root);
    }

    private TraceState childState(final String id, final ClientTrace child) {
        return new TraceState(id, child);
    }

    /**
     * Returns the in-progress trace with the given id, or the root trace if no trace with that id
     * is in progress. Callers that require the exact trace must verify the returned id.
     */
    private TraceState stackPeek(final String traceId) {
        return _inProgressTraces.getOrDefault(traceId, _rootTrace);
    }

    /** Removes and returns the in-progress trace with the given id, or {@code null} if there is none. */
    private TraceState stackPop(final String traceId) {
        return _inProgressTraces.remove(traceId);
    }

    private void stackPush(final String traceId, final TraceState traceState) {
        _inProgressTraces.put(traceId, traceState);
    }

    /**
     * Returns the id of the parent of the given trace id: everything before the last dash, or the
     * root trace id if the given id contains no dash.
     */
    private static String getParentId(final String traceId) {
        final int lastIndexOf = traceId.lastIndexOf('-');
        if (lastIndexOf == -1) {
            return CURRENT_PROCESS_TRACE_MARKER;
        }
        return traceId.substring(0, lastIndexOf);
    }

    /**
     * Utility class which contains a {@link Trace}, together with the id of that trace
     * and the data stream transfer state. The id is a {@link String} consisting of numbers
     * separated with a dash, each consecutive numbers representing a level deeper in the trace tree.
     * <p>
     * For example, if this trace has an id of {@code '0-0-1-2'}, the parent of this
     * trace has an id of {@code '0-0-1'} and the root trace has an id of {@code '0'}.
     */
    final class TraceState {

        private final String _id;
        private final ClientTrace _trace;
        private final DataStreamTransferStateManager _transfers;

        TraceState(final String id, final ClientTrace trace) {
            _id = argNotNull("id", id);
            _trace = argNotNull("trace", trace);
            // non-static inner class: shares the enclosing adapter's executor
            _transfers = new DataStreamTransferStateManager(_trace, _executor);
        }

        String id() {
            return _id;
        }

        ClientTrace trace() {
            return _trace;
        }

        DataStreamTransferStateManager transfers() {
            return _transfers;
        }

        ClientTrace newChild(final String name) {
            return _trace.newChild(name);
        }
    }
}
