package org.hansken.plugin.extraction.runtime.grpc.client;

import static java.lang.String.format;
import static java.util.Locale.ROOT;

import static org.hansken.extraction.plugin.grpc.ExtractionPluginServiceGrpc.newBlockingStub;
import static org.hansken.extraction.plugin.grpc.ExtractionPluginServiceGrpc.newStub;
import static org.hansken.plugin.extraction.runtime.grpc.common.Pack.transformerArgumentMap;
import static org.hansken.plugin.extraction.runtime.grpc.common.Pack.transformerRequest;
import static org.hansken.plugin.extraction.runtime.grpc.common.Unpack.transformerArgument;
import static org.hansken.plugin.extraction.runtime.grpc.common.Unpack.transformerResponse;
import static org.hansken.plugin.extraction.util.ArgChecks.argNotNull;

import static io.grpc.internal.GrpcUtil.authorityFromHostAndPort;
import static io.grpc.stub.MetadataUtils.newAttachHeadersInterceptor;

import java.io.IOException;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import org.hansken.extraction.plugin.grpc.ExtractionPluginServiceGrpc.ExtractionPluginServiceBlockingStub;
import org.hansken.extraction.plugin.grpc.ExtractionPluginServiceGrpc.ExtractionPluginServiceStub;
import org.hansken.extraction.plugin.grpc.RpcPluginInfo;
import org.hansken.extraction.plugin.grpc.RpcTransformerRequest;
import org.hansken.extraction.plugin.grpc.RpcTransformerResponse;
import org.hansken.plugin.extraction.api.LatLong;
import org.hansken.plugin.extraction.api.PluginInfo;
import org.hansken.plugin.extraction.api.TraceSearcher;
import org.hansken.plugin.extraction.api.TransformerArgument;
import org.hansken.plugin.extraction.api.TransformerLabel;
import org.hansken.plugin.extraction.api.Vector;
import org.hansken.plugin.extraction.runtime.grpc.client.api.ClientDataContext;
import org.hansken.plugin.extraction.runtime.grpc.client.api.ClientTrace;
import org.hansken.plugin.extraction.runtime.grpc.client.api.RemoteExtractionPlugin;
import org.hansken.plugin.extraction.runtime.grpc.common.Unpack;
import org.hansken.plugin.extraction.runtime.grpc.common.VersionUtil;
import org.hansken.plugin.extraction.runtime.json.PluginModule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.google.common.collect.Sets;

import com.google.protobuf.Any;
import com.google.protobuf.Empty;

import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;

/**
 * Client to connect to a running Extraction Plugin (server).
 * <p>
 * This client acts as a proxy between the clean API and an extraction
 * plugin implementation that is served by a remote gRPC server.
 *
 * @author Netherlands Forensic Institute
 */
public class ExtractionPluginClient implements RemoteExtractionPlugin, AutoCloseable {
    // The JavaTimeModule and PluginModule are added to deserialize/serialize ZonedDateTime.class, Vector.class, and LatLong.class
    private static final ObjectMapper MAPPER = new ObjectMapper().registerModule(new PluginModule()).registerModule(new JavaTimeModule());

    private static final Logger LOG = LoggerFactory.getLogger(ExtractionPluginClient.class);
    private static final Empty EMPTY = Empty.getDefaultInstance();
    /** Maximum size in bytes of a single inbound gRPC message (64 MiB). */
    private static final int MAX_MESSAGE_SIZE = 64 * 1024 * 1024;
    /** Maximum number of seconds {@link #close()} waits for the channel to terminate. */
    private static final long CLOSE_TIMEOUT_SECONDS = 5;

    /**
     * A map of the supported transformer types and their mapping to generic type (based on the Hansken trace model).
     * <p>
     * This map is also the deserialization table used for transformer arguments, so validation and
     * deserialization always agree on the supported types.
     */
    private static final Map<String, Class<?>> SUPPORTED_TRANSFORMER_TYPES = Map.of(
        "boolean", Boolean.class,
        "binary", byte[].class,
        "integer", Long.class,
        "real", Double.class,
        "string", String.class,
        "vector", Vector.class,
        "latLong", LatLong.class,
        "date", ZonedDateTime.class,
        "list", List.class,
        "map", Map.class
    );


    // communication channel and streams from client to server
    private final ManagedChannel _channel;
    private final ExtractionPluginServiceBlockingStub _blockingPluginStub;
    private final ExtractionPluginServiceStub _asyncPluginStub;
    private final String _target;

    // lazily fetched and cached by getRpcPluginInfo(); access is not synchronized
    private RpcPluginInfo _pluginInfo;

    /**
     * Constructor for the Extraction Plugin client.
     * <p>
     * During construction of this client object no actual connection is
     * created, so if the server is not available on the provided host and port,
     * no exception is thrown. Instead, an exception will be thrown each time
     * a new info or process-method is invoked. In case the client has to fail early
     * it should invoke pluginInfo() directly after constructing the client.
     * <p>
     * There is no retry policy configured, for this, use
     * {@link ExtractionPluginClient#ExtractionPluginClient(String, int, RetryPolicy)}.
     *
     * @param host Host where the extraction plugin server is running
     * @param port Port where the extraction plugin server is running on
     */
    public ExtractionPluginClient(final String host, final int port) {
        this(authorityFromHostAndPort(host, port));
    }

    /**
     * See {@link ExtractionPluginClient#ExtractionPluginClient(String, int)}, but with a {@link RetryPolicy}.
     *
     * @param host host where the extraction plugin server is running
     * @param port port where the extraction plugin server is running
     * @param retryPolicy the retry configuration policy to use, or {@code null} if no policy should be used
     */
    public ExtractionPluginClient(final String host, final int port, final RetryPolicy retryPolicy) {
        this(authorityFromHostAndPort(host, port), retryPolicy);
    }

    /**
     * Same as {@link ExtractionPluginClient#ExtractionPluginClient(String, int)}, but with a
     * {@link ManagedChannelBuilder#forTarget(String) target} string instead.
     *
     * @param target endpoint where the extraction plugin server is running
     */
    public ExtractionPluginClient(final String target) {
        this(target, null);
    }

    /**
     * Same as {@link ExtractionPluginClient#ExtractionPluginClient(String, int, RetryPolicy)}, but with a
     * {@link ManagedChannelBuilder#forTarget(String) target} string instead.
     *
     * @param target endpoint where the extraction plugin server is running
     * @param retryPolicy the retry configuration policy to use, or {@code null} if no policy should be used
     */
    public ExtractionPluginClient(final String target, final RetryPolicy retryPolicy) {
        this(target, retryPolicy, "");
    }

    /**
     * Same as {@link ExtractionPluginClient#ExtractionPluginClient(String, RetryPolicy)}, but with a
     * pluginId parameter for routing.
     *
     * @param target endpoint where the extraction plugin server is running
     * @param retryPolicy the retry configuration policy to use, or {@code null} if no policy should be used
     * @param pluginId the id of the plugin, which is set in the header, for routing purposes
     */
    public ExtractionPluginClient(final String target, final RetryPolicy retryPolicy, final String pluginId) {
        argNotNull("pluginId", pluginId);
        _target = argNotNull("target", target);

        final ManagedChannelBuilder<?> builder = ManagedChannelBuilder
            .forTarget(target)
            .usePlaintext()
            .maxInboundMessageSize(MAX_MESSAGE_SIZE);

        _channel = (retryPolicy == null)
            ? builder.build()
            : builder.defaultServiceConfig(retryPolicy.toMethodConfigMap())
                .enableRetry()
                .maxRetryAttempts(retryPolicy.maxAttempts())
                .build();

        // every call on both stubs carries the pluginId header (used for routing)
        final Metadata header = new Metadata();
        header.put(Metadata.Key.of("pluginId", Metadata.ASCII_STRING_MARSHALLER), pluginId);

        _blockingPluginStub = newBlockingStub(_channel).withInterceptors(newAttachHeadersInterceptor(header));
        _asyncPluginStub = newStub(_channel).withInterceptors(newAttachHeadersInterceptor(header));
    }

    @Override
    public boolean isCompatible() {
        final String remotePluginVersion = getPluginApiVersion();
        return VersionUtil.isCompatible(remotePluginVersion);
    }

    /**
     * Returns the API version of the Remote Extraction Plugin.
     *
     * @return API version
     */
    public String getPluginApiVersion() {
        return Unpack.pluginApiVersion(getRpcPluginInfo());
    }

    @Override
    public PluginInfo pluginInfo() {
        return Unpack.pluginInfo(getRpcPluginInfo());
    }

    /**
     * Fetches the plugin info from the server on first use and caches it for future calls.
     * <p>
     * The cache is not synchronized: concurrent first calls may each perform the RPC.
     *
     * @return the (cached) plugin info as returned by the server
     */
    private RpcPluginInfo getRpcPluginInfo() {
        if (_pluginInfo == null) {
            _pluginInfo = _blockingPluginStub.pluginInfo(EMPTY);
        }
        return _pluginInfo;
    }

    // default visibility for access in test
    final void process(final ClientTrace trace, final ClientDataContext dataContext, final TraceSearcher traceSearcher, final ReplyStream replyStream) {
        // first, set up a bi-directional communication stream with the ExtractionPlugin server;
        // the ReplyStream helper is used to make sure that both incoming and outgoing streams are
        // available in the ProtocolHandler.
        final ProtocolHandler protocolHandler = handler(trace, dataContext, traceSearcher, replyStream);
        final StreamObserver<Any> responseSender = _asyncPluginStub.process(protocolHandler);
        replyStream.init(responseSender);

        // we're all set up! Now send a start message to trigger the process() execution
        protocolHandler.start(trace, dataContext);

        // gRPC messages are exchanged asynchronously on gRPC threads, so wait here for the trace process to finish
        protocolHandler.await();
    }

    @Override
    public void process(final ClientTrace trace, final ClientDataContext dataContext) throws IOException {
        withUnwrappedExceptions(() -> process(trace, dataContext, null, new ReplyStream()));
    }

    @Override
    public void processDeferred(final ClientTrace trace, final ClientDataContext dataContext, final TraceSearcher searcher) throws IOException {
        withUnwrappedExceptions(() -> process(trace, dataContext, searcher, new ReplyStream()));
    }

    /**
     * Validates if a transformer request given a transformer label and specific arguments is correct.
     *
     * @param label The transformer label specifying the transformer method that is to be called.
     * @param arguments The actual arguments that the transformer method is to be invoked with.
     * @throws IllegalArgumentException if the argument names do not exactly match the parameter names of the label
     * @throws IllegalStateException if the label declares an unsupported parameter type, or if an argument
     *                               value is not an instance of the type declared for its parameter
     */
    public void validateTransformerRequest(final TransformerLabel label, final Map<String, TransformerArgument> arguments) throws IllegalArgumentException {
        final Map<String, String> parameters = label.getParameters();

        // Validate if arguments are provided for each parameter specified in the transformer label.
        if (!parameters.keySet().equals(arguments.keySet())) {
            throw new IllegalArgumentException(
                "The provided arguments do not match the parameters in the TransformerLabel. " +
                "Required: " + parameters.keySet() + ", Provided: " + arguments.keySet());
        }

        // Validate if each argument matches with the right label.
        for (final Map.Entry<String, String> parameter : parameters.entrySet()) {
            final String typeName = parameter.getValue();
            final Class<?> expectedType = SUPPORTED_TRANSFORMER_TYPES.get(typeName);
            if (expectedType == null) {
                throw new IllegalStateException(format(ROOT, "TransformerLabel for %s contains an illegal value: %s",
                    label.getMethodName(), typeName));
            }

            final TransformerArgument argument = arguments.get(parameter.getKey());
            // a null entry is reported as a type mismatch rather than failing with a NullPointerException
            final Object value = (argument == null) ? null : argument.getArgument();
            if (!expectedType.isInstance(value)) {
                // format exactly once: the provided value may contain '%', which must not be parsed as a format specifier
                throw new IllegalStateException(format(ROOT,
                    "Provided argument is not of type provided in Transformer. Transformer Type: %s, Provided: %s, Expected: %s",
                    typeName, value, expectedType));
            }
        }
    }

    @Override
    public TransformerArgument transform(final TransformerLabel label, final Map<String, String> stringMap) {

        final Map<String, TransformerArgument> arguments = deserializeJsonArguments(label, stringMap);

        // Validate if the transformer requests and arguments are correctly formed.
        validateTransformerRequest(label, arguments);

        // Construct the RPC request by setting the transformer method to call and the arguments.
        final RpcTransformerRequest request = transformerRequest(label, transformerArgumentMap(arguments));

        // Calls the transform() method of the extraction plugin server using gRPC giving it our specific request.
        // The transform method is responsible for relaying the request to the right transformer method in the plugin.
        final RpcTransformerResponse response = invokeTransformer(request);

        // Retrieve the returned result in the response.
        return new TransformerArgument(transformerArgument(transformerResponse(response)));
    }

    /**
     * Deserializes JSON-encoded transformer arguments according to the parameter types declared in a label.
     * <p>
     * Only arguments whose name is also a parameter of the label are deserialized; other entries are ignored.
     *
     * @param label the transformer label declaring parameter names and their type names
     * @param arguments JSON-encoded argument values, keyed by parameter name
     * @return a new mutable map from parameter name to deserialized argument
     * @throws IllegalArgumentException if a value is blank, cannot be parsed, or has an unsupported type
     */
    public Map<String, TransformerArgument> deserializeJsonArguments(final TransformerLabel label, final Map<String, String> arguments) {
        final Map<String, TransformerArgument> map = new HashMap<>();
        for (final String key : Sets.intersection(arguments.keySet(), label.getParameters().keySet())) {
            final String type = label.getParameters().get(key);
            final String value = arguments.get(key);
            map.put(key, new TransformerArgument(createTransformerArgument(type, value)));
        }
        return map;
    }

    /**
     * Deserializes a single JSON-encoded transformer argument.
     *
     * @param type the transformer type name, one of the keys of the supported transformer types
     * @param value the JSON-encoded value
     * @return the deserialized value
     * @throws IllegalArgumentException if {@code value} is {@code null} or blank, cannot be parsed as JSON,
     *                                  or {@code type} is not supported
     */
    public static Object createTransformerArgument(final String type, final String value) {
        if (value == null || value.isBlank()) {
            throw new IllegalArgumentException("Json input was empty or null.");
        }
        try {
            return jsonConverter(type, value);
        }
        catch (final JsonProcessingException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Parses a JSON string into the Java type that corresponds to the given transformer type name.
     * <p>
     * Dates are additionally converted to the UTC zone.
     *
     * @param type the transformer type name
     * @param jsonString the JSON-encoded value
     * @return the parsed value
     * @throws JsonProcessingException if the JSON cannot be parsed into the target type
     * @throws IllegalArgumentException if {@code type} is not supported
     */
    private static Object jsonConverter(final String type, final String jsonString) throws JsonProcessingException {
        if ("date".equals(type)) {
            final ZonedDateTime date = MAPPER.readValue(jsonString, ZonedDateTime.class);
            // Convert The ZonedDateTime ZoneId to UTC to make sure that the timezone is supported by the python pytz module.
            // Going through toInstant() keeps sub-second precision (toEpochSecond() would truncate it).
            final Instant instant = date.toInstant();
            return instant.atZone(ZoneId.of("UTC"));
        }
        // all other types deserialize directly into their mapped class
        final Class<?> targetType = SUPPORTED_TRANSFORMER_TYPES.get(type);
        if (targetType == null) {
            throw new IllegalArgumentException("Unsupported argument type");
        }
        return MAPPER.readValue(jsonString, targetType);
    }

    /**
     * Invoke a transformer with a given request. This method is isolated such that it can be easily used for mocking.
     * @param request The transformer request containing the transformer name and arguments we want to invoke it with.
     * @return The returned response from the transformer.
     */
    protected RpcTransformerResponse invokeTransformer(final RpcTransformerRequest request) {
        return _blockingPluginStub.transform(request);
    }

    /**
     * Runs the given action and translates a gRPC {@link StatusRuntimeException} into its underlying cause.
     *
     * @param runnable the action to run
     * @throws IOException if the unwrapped cause is an {@link IOException}
     */
    private void withUnwrappedExceptions(final Runnable runnable) throws IOException {
        try {
            runnable.run();
        }
        catch (final StatusRuntimeException e) {
            LOG.debug("Got a gRPC StatusRuntimeException (status: {}), logging this here, since the unwrapped exception is rethrown from here", e.getStatus(), e);
            // rethrow the unwrapped exception for better understandable error messages client-side
            throw unwrap(e);
        }
    }

    /**
     * Gets the client's connection target, a String which consists of a host and port. For example: localhost:8999.
     *
     * @return a host and port combined into an authority string
     */
    public String getTarget() {
        return _target;
    }

    /**
     * Returns a handler for orchestrating the processing of a trace over gRPC.
     * <p>
     * This method is exposed as protected method so we can hook into this object using unit tests.
     *
     * @param trace the trace to process
     * @param dataContext the data context to process
     * @param searcher the trace searcher
     * @param stream the outbound communication stream
     * @return a gRPC handler
     */
    protected ProtocolHandler handler(final ClientTrace trace, final ClientDataContext dataContext,
                                      final TraceSearcher searcher, final ReplyStream stream) {
        return new ProtocolHandler(stream, adapter(trace, dataContext, searcher));
    }

    /**
     * Gets the current connectivity state. Note the result may soon become outdated.
     * <p>
     * The channel is queried with {@code requestConnection = true}, so an idle channel is asked to start connecting.
     *
     * @return the state of the connection
     */
    ConnectivityState getState() {
        return _channel.getState(true);
    }

    /**
     * Returns an adapter that translates gRPC protocol messages to clean API calls.
     * <p>
     * This method is exposed as protected method so we can hook into this object using unit tests.
     *
     * @param trace the trace on which to make the API calls
     * @param dataContext the data context on which to make the API calls
     * @param searcher the searcher on which to make the API calls
     * @return gRPC adapter object
     */
    protected ExtractionPluginGrpcAdapter adapter(final ClientTrace trace, final ClientDataContext dataContext, final TraceSearcher searcher) {
        return new ExtractionPluginGrpcAdapter(trace, dataContext, new ExtractionPluginDataReader(dataContext), searcher);
    }

    /**
     * Forcefully shuts down the channel and waits a bounded time for it to terminate.
     * A warning is logged when the channel has not terminated within that time.
     *
     * @throws InterruptedException if interrupted while waiting for termination
     */
    @Override
    public void close() throws InterruptedException {
        final boolean terminated = _channel.shutdownNow().awaitTermination(CLOSE_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        if (!terminated) {
            LOG.warn("gRPC channel to {} did not terminate within {} seconds", _target, CLOSE_TIMEOUT_SECONDS);
        }
    }

    /**
     * Improve exceptions thrown by gRPC for extraction plugin users, by unpacking io.grpc.StatusRuntimeException
     * and by (re)throwing the original exception.
     * <p>
     * Nested {@link StatusRuntimeException} causes are unwrapped recursively. If there is no cause,
     * {@code e} itself is rethrown; runtime exceptions and errors found as cause are rethrown as-is.
     *
     * @param e original exception
     * @return a new IllegalStateException if the unwrapped exception is not a IOException or RuntimeException
     * @throws IOException if the unwrapped exception is a IOException
     */
    private RuntimeException unwrap(final StatusRuntimeException e) throws IOException {
        final Throwable cause = e.getCause();
        if (cause == null) {
            // nothing to unwrap, just rethrow the original exception
            throw e;
        }
        else if (cause instanceof StatusRuntimeException) {
            return unwrap((StatusRuntimeException) cause);
        }
        else if (cause instanceof RuntimeException) {
            throw (RuntimeException) cause;
        }
        else if (cause instanceof IOException) {
            throw (IOException) cause;
        }
        else if (cause instanceof Error) {
            throw (Error) cause;
        }
        else {
            // unexpected, since plugin.process() does not throw checked exceptions other than IOException
            // however keep a safeguard in case checked exceptions make it here.
            return new IllegalStateException(cause);
        }
    }
}
