package org.hansken.plugin.extraction.runtime.grpc.client;

import static org.hansken.plugin.extraction.runtime.grpc.client.ExtractionPluginGrpcAdapter.CURRENT_PROCESS_TRACE_MARKER;

import java.util.concurrent.CountDownLatch;

import org.hansken.extraction.plugin.grpc.RpcFinish;
import org.hansken.extraction.plugin.grpc.RpcPartialFinishWithError;
import org.hansken.extraction.plugin.grpc.RpcStart;
import org.hansken.plugin.extraction.api.DataContext;
import org.hansken.plugin.extraction.api.Trace;
import org.hansken.plugin.extraction.runtime.grpc.common.Pack;
import org.hansken.plugin.extraction.runtime.grpc.common.Unpack;

import com.google.protobuf.Any;

import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;

/**
 * Handler for orchestrating the processing of a {@link Trace trace} over gRPC, which itself
 * is an observer for a stream in order to receive messages from the server.
 * <p>
 * Typical usage: call {@link #start(Trace, DataContext)} to send the START message, then call
 * {@link #await()} to block until the server completes (or fails) the stream. Messages received in
 * the meantime are executed by the {@link ExtractionPluginGrpcAdapter}, and any reply it produces is
 * sent back over the {@link ReplyStream}. Errors are recorded on the adapter and rethrown by
 * {@link #await()}.
 */
// TODO: HANSKEN-13583 - document FINISH & stream closing behavior (once all test cases are properly implemented we know what to document)
public class ProtocolHandler implements StreamObserver<Any> {
    private final ExtractionPluginGrpcAdapter _adapter;
    private final CountDownLatch _finishLatch;
    private final ReplyStream _replyStream;

    // flag indicating that start() has begun sending the START message to the server
    private volatile boolean _started;
    // flag indicating that a FINISH or PARTIAL_FINISH_WITH_ERROR message has been received from the server
    private volatile boolean _finishReceived;

    /**
     * Initialize a handler for processing a {@link Trace trace} over gRPC.
     *
     * @param replyStream the outbound message stream
     * @param adapter the adapter which can execute the received messages
     */
    public ProtocolHandler(final ReplyStream replyStream, final ExtractionPluginGrpcAdapter adapter) {
        _adapter = adapter;
        _replyStream = replyStream;
        _finishLatch = new CountDownLatch(1);
        _started = false;
    }

    /**
     * Sends the START message over the communication stream. This message is defined in the gRPC
     * protocol. When the ExtractionPlugin receives this message it actually starts the process method.
     * <p>
     * If building or sending the message fails, the remote is informed (which also disconnects), the
     * error is recorded, and the original exception is rethrown. In that case {@link #await()} does
     * not have to be called.
     *
     * @param trace the trace to process
     * @param context the context to process
     */
    @SuppressWarnings("checkstyle:illegalcatch")
    public void start(final Trace trace, final DataContext context) {
        try {
            // set before sending, so that a reply arriving right after the send is not rejected by onNext()
            _started = true;
            _replyStream.reply(Any.pack(
                RpcStart.newBuilder()
                    .setTrace(Pack.trace(CURRENT_PROCESS_TRACE_MARKER, trace))
                    .setDataContext(Pack.metaOfDataContext(context))
                    .build()));
        }
        catch (final Throwable t) {
            // oops, we got an exception, first inform the server, this will also disconnect

            // note that these are typically serialization exceptions. gRPC exceptions (such as a remote
            // server not being available) are passed to the onError() handler
            try {
                setErrorFromInternal(Status.Code.CANCELLED, t);
            }
            catch (final Throwable reportFailure) {
                // do not let a failure while reporting mask the original cause; keep it as context instead
                if (reportFailure != t) {
                    t.addSuppressed(reportFailure);
                }
            }

            // rethrow the exception to indicate that we never started, and so that await() does
            // not have to be called
            throw t;
        }
    }

    /**
     * Blocks and waits for the processing of the {@link Trace trace} to be finished,
     * whether or not due to an error.
     * <p>
     * If the waiting thread is interrupted, the remote is informed with status CANCELLED, the
     * interruption is recorded as an error, and the thread's interrupt flag is restored.
     *
     * @throws RuntimeException any exception that occurred during remote processing; errors that are
     *                          not {@link RuntimeException}s are wrapped in an {@link IllegalStateException}
     */
    public void await() {
        try {
            _finishLatch.await();
        }
        catch (final InterruptedException e) {
            try {
                setErrorFromInternal(Status.Code.CANCELLED, e);
            }
            finally {
                Thread.currentThread().interrupt();
            }
        }
        finally {
            handleFinishErrors();
        }
    }

    /**
     * Handles a message received from the server.
     * <p>
     * The message is executed by the adapter and any reply it produces is sent back. On a FINISH or
     * PARTIAL_FINISH_WITH_ERROR message, the outbound stream is completed. A partial-finish error is
     * recorded before that happens. A failure while handling the message is recorded on the adapter
     * and reported to the remote with status ABORTED.
     *
     * @param message the message received from the server
     */
    @Override
    @SuppressWarnings("checkstyle:illegalcatch")
    public void onNext(final Any message) {
        if (!_started) {
            rejectUnexpectedMessage();
            return;
        }
        try {
            _adapter.execute(message).ifPresent(_replyStream::reply);
            if (message.is(RpcFinish.class) || message.is(RpcPartialFinishWithError.class)) {
                // register the receiving of the finish message
                _finishReceived = true;
                // record a reported partial-finish error before closing our side, so that it is not
                // lost if completing the reply stream fails
                setPartialFinishError(message);
                // the extraction plugin informed us the trace processing is finished;
                // we have handled the message, so now we have to signal stream completion
                _replyStream.onCompleted();
            }
        }
        catch (final InterruptedException e) {
            Thread.currentThread().interrupt();
            abort(e);
        }
        catch (final Throwable t) {
            abort(t);
        }
    }

    /**
     * Callable that processes errors.
     * <p>
     * This method will _not_ signal the remote server that there was an exception, and thus
     * this method will not disconnect from the server. Please use setErrorFromInternal() if
     * the error cause was from an internal routine and a disconnect should be triggered.
     * <p>
     * Since we are in error state, no new incoming messages are expected. Therefore
     * onError() lowers a latch so that await() finishes. The throwable passed to this
     * method will bubble back to the caller of await().
     *
     * @param t error that occurred.
     */
    @Override
    public void onError(final Throwable t) {
        try {
            _adapter.error(t);
        }
        finally {
            _finishLatch.countDown();
        }
    }

    /**
     * Called when the server completes the stream. Completion without a preceding FINISH (or
     * PARTIAL_FINISH_WITH_ERROR) message is treated as an error.
     */
    @Override
    public void onCompleted() {
        try {
            if (!_finishReceived) {
                onError(new IllegalStateException("finish message has not been received"));
            }
        }
        finally {
            // stream is completed, let the main thread continue
            _finishLatch.countDown();
        }
    }

    /**
     * Handles a message that arrived before {@link #start(Trace, DataContext)} was called. The error
     * is recorded (and the remote informed) BEFORE the latch is lowered, so that a thread blocked in
     * {@link #await()} always observes the error once it wakes up.
     */
    private void rejectUnexpectedMessage() {
        final IllegalStateException error =
            new IllegalStateException("received a message from server, but processing of the trace has not been started yet");
        try {
            abort(error);
        }
        finally {
            _finishLatch.countDown();
        }
    }

    /**
     * Records an error that occurred while handling an incoming message, and reports it to the remote
     * with status ABORTED.
     *
     * @param t the error that occurred
     */
    private void abort(final Throwable t) {
        _adapter.error(t);
        _replyStream.onError(Pack.asStatusRuntimeException(Status.Code.ABORTED, t));
    }

    /**
     * Informs the remote about an internal error (with the given status code), then records the error
     * locally via {@link #onError(Throwable)}. The local recording also happens if informing the
     * remote fails.
     *
     * @param code the gRPC status code to report to the remote
     * @param t the internal error
     */
    private void setErrorFromInternal(final Status.Code code, final Throwable t) {
        try {
            // try to inform the remote that we had an internal error
            // and set the communication channel in error state
            _replyStream.onError(Pack.asStatusRuntimeException(code, t));
        }
        finally {
            // further internal error processing
            onError(t);
        }
    }

    /**
     * If the given message is a {@link RpcPartialFinishWithError}, records its error description on the
     * adapter as an {@link ExtractionPluginException} (without a stack trace), so that it is rethrown
     * by {@link #await()}. Other messages are ignored.
     *
     * @param message the message received from the server
     */
    private void setPartialFinishError(final Any message) {
        if (message.is(RpcPartialFinishWithError.class)) {
            final RpcPartialFinishWithError partialFinish = Unpack.any(message, RpcPartialFinishWithError.class);
            _adapter.error(ExtractionPluginException.getInstanceWithoutStacktrace(partialFinish.getErrorDescription()),
                "PartialFinishError. An exception was thrown by the Extraction Plugin."
            );
        }
    }

    /**
     * Rethrows the error recorded on the adapter, if any. A {@link RuntimeException} is rethrown as is;
     * any other throwable is wrapped in an {@link IllegalStateException}.
     */
    private void handleFinishErrors() {
        _adapter.error().ifPresent(error -> {
            if (error instanceof RuntimeException) {
                throw (RuntimeException) error;
            }
            else {
                throw new IllegalStateException(error);
            }
        });
    }
}
