package org.hansken.plugin.extraction.runtime.grpc.client;

import static org.hansken.extraction.plugin.grpc.ExtractionPluginServiceGrpc.newBlockingStub;
import static org.hansken.extraction.plugin.grpc.ExtractionPluginServiceGrpc.newStub;
import static org.hansken.plugin.extraction.util.ArgChecks.argNotNull;

import static io.grpc.internal.GrpcUtil.authorityFromHostAndPort;
import static io.grpc.stub.MetadataUtils.attachHeaders;

import java.io.IOException;
import java.util.concurrent.TimeUnit;

import org.hansken.extraction.plugin.grpc.ExtractionPluginServiceGrpc.ExtractionPluginServiceBlockingStub;
import org.hansken.extraction.plugin.grpc.ExtractionPluginServiceGrpc.ExtractionPluginServiceStub;
import org.hansken.extraction.plugin.grpc.RpcPluginInfo;
import org.hansken.plugin.extraction.api.PluginInfo;
import org.hansken.plugin.extraction.api.TraceSearcher;
import org.hansken.plugin.extraction.runtime.grpc.client.api.ClientDataContext;
import org.hansken.plugin.extraction.runtime.grpc.client.api.ClientTrace;
import org.hansken.plugin.extraction.runtime.grpc.client.api.RemoteExtractionPlugin;
import org.hansken.plugin.extraction.runtime.grpc.common.Unpack;
import org.hansken.plugin.extraction.runtime.grpc.common.VersionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.protobuf.Any;
import com.google.protobuf.Empty;

import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;

/**
 * Client to connect to a running Extraction Plugin (server).
 * <p>
 * This client acts as a proxy between the clean API and an extraction
 * plugin implementation that is served by a remote gRPC server.
 * <p>
 * The underlying channel uses plaintext transport and accepts inbound messages of up to 64 MiB.
 * Call {@link #close()} to shut the channel down when the client is no longer needed.
 *
 * @author Netherlands Forensic Institute
 */
public class ExtractionPluginClient implements RemoteExtractionPlugin, AutoCloseable {
    private static final Logger LOG = LoggerFactory.getLogger(ExtractionPluginClient.class);

    private static final Empty EMPTY = Empty.getDefaultInstance();
    private static final int MAX_MESSAGE_SIZE = 64 * 1024 * 1024;
    // maximum time close() waits for the channel to terminate after shutdownNow()
    private static final long CLOSE_TIMEOUT_SECONDS = 5;

    // communication channel and streams from client to server
    private final ManagedChannel _channel;
    private final ExtractionPluginServiceBlockingStub _blockingPluginStub;
    private final ExtractionPluginServiceStub _asyncPluginStub;
    private final String _target;

    // lazily fetched from the server on first use, see getRpcPluginInfo()
    private RpcPluginInfo _pluginInfo;

    /**
     * Constructor for the Extraction Plugin client.
     * <p>
     * During construction of this client object no actual connection is
     * created, so if the server is not available on the provided host and port,
     * no exception is thrown. Instead, an exception will be thrown each time
     * a new info or process-method is invoked. In case the client has to fail early
     * it should invoke pluginInfo() directly after constructing the client.
     * <p>
     * There is no retry policy configured, for this, use
     * {@link ExtractionPluginClient#ExtractionPluginClient(String, int, RetryPolicy)}.
     *
     * @param host Host where the extraction plugin server is running
     * @param port Port where the extraction plugin server is running on
     */
    public ExtractionPluginClient(final String host, final int port) {
        this(authorityFromHostAndPort(host, port));
    }

    /**
     * See {@link ExtractionPluginClient#ExtractionPluginClient(String, int)}, but with a {@link RetryPolicy}.
     *
     * @param host host where the extraction plugin server is running
     * @param port port where the extraction plugin server is running
     * @param retryPolicy the retry configuration policy to use, or {@code null} if no policy should be used
     */
    public ExtractionPluginClient(final String host, final int port, final RetryPolicy retryPolicy) {
        this(authorityFromHostAndPort(host, port), retryPolicy);
    }

    /**
     * Same as {@link ExtractionPluginClient#ExtractionPluginClient(String, int)}, but with a
     * {@link ManagedChannelBuilder#forTarget(String) target} string instead.
     *
     * @param target endpoint where the extraction plugin server is running
     */
    public ExtractionPluginClient(final String target) {
        this(target, null);
    }

    /**
     * Same as {@link ExtractionPluginClient#ExtractionPluginClient(String, int, RetryPolicy)}, but with a
     * {@link ManagedChannelBuilder#forTarget(String) target} string instead.
     * <p>
     * No plugin id is used for routing: an empty {@code pluginId} header value is sent.
     *
     * @param target endpoint where the extraction plugin server is running
     * @param retryPolicy the retry configuration policy to use, or {@code null} if no policy should be used
     */
    public ExtractionPluginClient(final String target, final RetryPolicy retryPolicy) {
        this(target, retryPolicy, "");
    }

    /**
     * Same as {@link ExtractionPluginClient#ExtractionPluginClient(String, RetryPolicy)}, but with a
     * pluginId parameter for routing.
     * <p>
     * The given plugin id is attached as a {@code pluginId} header to every call made through this client.
     *
     * @param target endpoint where the extraction plugin server is running, must not be {@code null}
     * @param retryPolicy the retry configuration policy to use, or {@code null} if no policy should be used
     * @param pluginId the id of the plugin, which is set in the header, for routing purposes; must not be {@code null}
     */
    public ExtractionPluginClient(final String target, final RetryPolicy retryPolicy, final String pluginId) {
        argNotNull("pluginId", pluginId);
        _target = argNotNull("target", target);

        final ManagedChannelBuilder<?> builder = ManagedChannelBuilder
            .forTarget(target)
            .usePlaintext()
            .maxInboundMessageSize(MAX_MESSAGE_SIZE);

        _channel = (retryPolicy == null)
            ? builder.build()
            : builder.defaultServiceConfig(retryPolicy.toMethodConfigMap())
                .enableRetry()
                .maxRetryAttempts(retryPolicy.maxAttempts())
                .build();

        final Metadata header = new Metadata();
        header.put(Metadata.Key.of("pluginId", Metadata.ASCII_STRING_MARSHALLER), pluginId);

        _blockingPluginStub = attachHeaders(newBlockingStub(_channel), header);
        _asyncPluginStub = attachHeaders(newStub(_channel), header);
    }

    @Override
    public boolean isCompatible() {
        final String remotePluginVersion = Unpack.pluginApiVersion(getRpcPluginInfo());
        return VersionUtil.isCompatible(remotePluginVersion);
    }

    @Override
    public PluginInfo pluginInfo() {
        return Unpack.pluginInfo(getRpcPluginInfo());
    }

    /**
     * Returns the raw plugin info of the remote plugin. It is fetched with a blocking call on
     * first use and cached afterwards.
     * <p>
     * If the remote call throws, nothing is cached and the next invocation tries again.
     *
     * @return the remote plugin info
     */
    private RpcPluginInfo getRpcPluginInfo() {
        // cache pluginInfo for future calls
        if (_pluginInfo == null) {
            _pluginInfo = _blockingPluginStub.pluginInfo(EMPTY);
        }
        return _pluginInfo;
    }

    /**
     * Processes a trace over a bi-directional gRPC stream and blocks until processing has finished.
     * <p>
     * Default visibility for access in test, so that a custom {@link ReplyStream} can be supplied.
     *
     * @param trace the trace to process
     * @param dataContext the data context of the trace
     * @param traceSearcher the searcher to use, may be {@code null} (as passed by {@link #process(ClientTrace, ClientDataContext)})
     * @param replyStream the outbound stream, initialized here with the request observer of the gRPC call
     */
    final void process(final ClientTrace trace, final ClientDataContext dataContext, final TraceSearcher traceSearcher, final ReplyStream replyStream) {
        // first, set up a bi-directional communication stream with the ExtractionPlugin server
        // a helper object ReplySender is used to make sure that both incoming and outgoing streams are
        // available in the ProtocolHandler.
        final ProtocolHandler protocolHandler = handler(trace, dataContext, traceSearcher, replyStream);
        final StreamObserver<Any> responseSender = _asyncPluginStub.process(protocolHandler);
        replyStream.init(responseSender);

        // we're all set up! Now send a start message to trigger the process() execution
        protocolHandler.start(trace, dataContext);

        // gRPC messages are exchanged asynchronously on gRPC threads, so wait for the trace process to finish
        protocolHandler.await();
    }

    /**
     * {@inheritDoc}
     * <p>
     * A {@link StatusRuntimeException} raised during processing is logged and then unwrapped, so the
     * caller sees the original cause (see {@link #unwrap(StatusRuntimeException)}).
     */
    @Override
    public void process(final ClientTrace trace, final ClientDataContext dataContext) throws IOException {
        try {
            process(trace, dataContext, null, new ReplyStream());
        } catch (final StatusRuntimeException e) {
            LOG.error("Got a gRPC StatusRuntimeException (status: {}), logging this here, since the unwrapped exception is rethrown from here", e.getStatus(), e);
            // rethrow the unwrapped exception for better understandable error messages client-side
            throw unwrap(e);
        }
    }

    /**
     * {@inheritDoc}
     * <p>
     * Unlike {@link #process(ClientTrace, ClientDataContext)}, a {@link StatusRuntimeException} is not
     * unwrapped here; it propagates to the caller as is.
     */
    @Override
    public void processDeferred(final ClientTrace trace, final ClientDataContext dataContext, final TraceSearcher searcher) {
        process(trace, dataContext, searcher, new ReplyStream());
    }

    /**
     * Gets the client's connection target, as passed to {@link ManagedChannelBuilder#forTarget(String)}.
     * When the client was constructed from a host and port this is an authority string, for example:
     * localhost:8999.
     *
     * @return the target string of the underlying channel
     */
    public String getTarget() {
        return _target;
    }

    /**
     * Returns a handler for orchestrating the processing of a trace over gRPC.
     * <p>
     * This method is exposed as protected method so we can hook into this object using unit tests.
     *
     * @param trace the trace to process
     * @param dataContext the data context to process
     * @param searcher the trace searcher
     * @param stream the outbound communication stream
     * @return a gRPC handler
     */
    protected ProtocolHandler handler(final ClientTrace trace, final ClientDataContext dataContext,
                                      final TraceSearcher searcher, final ReplyStream stream) {
        return new ProtocolHandler(stream, adapter(trace, dataContext, searcher));
    }

    /**
     * Gets the current connectivity state. Note the result may soon become outdated.
     *
     * @param requestConnection if {@code true}, the channel will try to make a connection if it is
     *     currently IDLE
     * @return the state of the connection
     */
    ConnectivityState getState(final boolean requestConnection) {
        return _channel.getState(requestConnection);
    }

    /**
     * Returns an adapter that translates gRPC protocol messages to clean API calls.
     * <p>
     * This method is exposed as protected method so we can hook into this object using unit tests.
     *
     * @param trace the trace on which to make the API calls
     * @param dataContext the data context on which to make the API calls
     * @param searcher the searcher on which to make the API calls
     * @return gRPC adapter object
     */
    protected ExtractionPluginGrpcAdapter adapter(final ClientTrace trace, final ClientDataContext dataContext, final TraceSearcher searcher) {
        return new ExtractionPluginGrpcAdapter(trace, dataContext, new ExtractionPluginDataReader(dataContext), searcher);
    }

    /**
     * Shuts down the underlying channel immediately and waits a bounded time for it to terminate.
     * A warning is logged if the channel has not terminated within that time.
     *
     * @throws InterruptedException if interrupted while waiting for termination
     */
    @Override
    public void close() throws InterruptedException {
        final boolean terminated = _channel.shutdownNow().awaitTermination(CLOSE_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        if (!terminated) {
            LOG.warn("gRPC channel to {} did not terminate within {} seconds", _target, CLOSE_TIMEOUT_SECONDS);
        }
    }

    /**
     * Improve exceptions thrown by gRPC for extraction plugin users, by unpacking io.grpc.StatusRuntimeException
     * and by (re)throwing the original exception.
     * <p>
     * Nested {@link StatusRuntimeException} causes are unwrapped recursively. A cause that is a
     * {@link RuntimeException}, {@link IOException} or {@link Error} is thrown directly. Without a
     * cause, {@code e} itself is rethrown.
     *
     * @param e original exception
     * @return a new IllegalStateException if the unwrapped exception is not a IOException or RuntimeException
     * @throws IOException if the unwrapped exception is a IOException
     */
    private RuntimeException unwrap(final StatusRuntimeException e) throws IOException {
        final Throwable cause = e.getCause();
        if (cause == null) {
            // nothing to unwrap, just rethrow the original exception
            throw e;
        } else if (cause instanceof StatusRuntimeException) {
            return unwrap((StatusRuntimeException) cause);
        } else if (cause instanceof RuntimeException) {
            throw (RuntimeException) cause;
        } else if (cause instanceof IOException) {
            throw (IOException) cause;
        } else if (cause instanceof Error) {
            throw (Error) cause;
        } else {
            // unexpected, since plugin.process() does not throw checked exceptions other than IOException
            // however keep a safeguard in case checked exceptions make it here.
            return new IllegalStateException(cause);
        }
    }
}
