package org.hansken.plugin.extraction.runtime.grpc.server.proxy;

import static java.lang.Integer.min;
import static java.lang.Math.toIntExact;

import static org.hansken.plugin.extraction.util.ArgChecks.argNotNull;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.hansken.extraction.plugin.grpc.RpcBatchUpdate;
import org.hansken.extraction.plugin.grpc.RpcBeginChild;
import org.hansken.extraction.plugin.grpc.RpcBeginDataStream;
import org.hansken.extraction.plugin.grpc.RpcFinish;
import org.hansken.extraction.plugin.grpc.RpcFinishChild;
import org.hansken.extraction.plugin.grpc.RpcFinishDataStream;
import org.hansken.extraction.plugin.grpc.RpcPartialFinishWithError;
import org.hansken.extraction.plugin.grpc.RpcProfile;
import org.hansken.extraction.plugin.grpc.RpcSearchResult;
import org.hansken.extraction.plugin.grpc.RpcSync;
import org.hansken.extraction.plugin.grpc.RpcWriteDataStream;
import org.hansken.plugin.extraction.api.BatchSearchResult;
import org.hansken.plugin.extraction.api.DataContext;
import org.hansken.plugin.extraction.api.SearchResult;
import org.hansken.plugin.extraction.api.Trace;
import org.hansken.plugin.extraction.api.TraceSearcher.SearchScope;
import org.hansken.plugin.extraction.api.transformations.DataTransformation;
import org.hansken.plugin.extraction.runtime.grpc.common.Pack;
import org.hansken.plugin.extraction.runtime.grpc.common.Unpack;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;

import io.grpc.Status;
import io.grpc.stub.StreamObserver;

/**
 * Facade through which the server can make and receive calls over RPC. It translates method calls
 * to the necessary gRPC messages and blocks for a response where necessary.
 * <p>
 * Trace updates (enrichments, children, data streams) are not sent immediately: they are buffered and
 * sent as a single {@link RpcBatchUpdate} once the next action would make the batch exceed the maximum
 * message size, or sent along with the final {@link RpcFinish} / {@link RpcPartialFinishWithError} message.
 * <p>
 * At most one request can await a response at a time; sending another request in the meantime results in
 * an {@link IllegalStateException}.
 */
public class GrpcFacade {

    // the message wire format takes up some bytes as well, choose a safe value here for now
    private static final int MESSAGE_OVERHEAD_SIZE = 1024;
    // overhead per buffered action, chosen a safe value again
    private static final int PER_ACTION_OVERHEAD_SIZE = 8;
    private static final int MINIMUM_MESSAGE_SIZE = 1024 * 1024;
    private static final int SEARCH_REQUEST_COUNT_LIMIT = 50; // considering max gRPC msg size of 64Mb
    // number of buffered actions after which a batch flush is followed by an explicit sync round trip
    private static final int MAX_UNSYNCED_REQUESTS = 1000;

    private final int _bufferedMessageSize;
    private final int _maximumReadChunkSize;
    private final int _maximumWriteChunkSize;

    private final BlockingQueue<Any> _incomingMessages;
    private final StreamObserver<Any> _outgoingMessages;
    // true while a request sent through call() is waiting for its response
    private final AtomicBoolean _expectingResponse;
    private final List<Any> _bufferedUpdates;

    // estimated serialized size of the actions in _bufferedUpdates, including per-action overhead
    private long _currentlyBufferedSize;

    private int _unsyncedRequests;

    /**
     * Create a new {@link GrpcFacade RPC facade}, which will send the request on the given {@link StreamObserver},
     * and if necessary, wait for reply on given {@link BlockingQueue}.
     *
     * @param incomingMessages queue where results of a request are pushed on
     * @param outgoingMessages observer on which to publish requests
     * @param maximumMessageSize the maximum size of the wire format of a single RPC message
     * @throws IllegalArgumentException if the maximum message size is less than {@link GrpcFacade#MINIMUM_MESSAGE_SIZE}
     */
    public GrpcFacade(final BlockingQueue<Any> incomingMessages, final StreamObserver<Any> outgoingMessages, final int maximumMessageSize) {
        _incomingMessages = argNotNull("incomingMessages", incomingMessages);
        _outgoingMessages = argNotNull("outgoingMessages", outgoingMessages);
        if (maximumMessageSize < MINIMUM_MESSAGE_SIZE) {
            throw new IllegalArgumentException(
                "maximum message size is too small: " + maximumMessageSize +
                    ", should be at least " + MINIMUM_MESSAGE_SIZE
            );
        }
        _bufferedMessageSize = maximumMessageSize;
        _maximumReadChunkSize = maximumMessageSize - MESSAGE_OVERHEAD_SIZE;
        _maximumWriteChunkSize = _maximumReadChunkSize - PER_ACTION_OVERHEAD_SIZE;
        _expectingResponse = new AtomicBoolean(false);
        _bufferedUpdates = new ArrayList<>(64);
    }

    /**
     * Read from the data sequence currently processed by the trace, see {@link DataContext#data()}.
     * Reads larger than a single message allows are split into multiple blocking requests.
     * <p>
     * <strong>Note:</strong> callers must ensure that the trace data contains at least {@code count} bytes
     * from given {@code position}, else the result of this call is undefined.
     *
     * @param position the position to seek to and read from
     * @param count the number of bytes to read
     * @param traceUid the uid of the trace to read from
     * @param type the type of the data stream to read from
     * @return the read bytes
     */
    public byte[] readFromTraceData(final long position, final int count, final String traceUid, final String type) {
        final byte[] buffer = new byte[count];
        int remaining = count;
        while (remaining > 0) {
            final long currentPosition = position + (count - remaining);
            final int toRead = min(remaining, _maximumReadChunkSize);
            final byte[] read = Unpack.bytes(call(Pack.readParameters(currentPosition, toRead, traceUid, type)));
            System.arraycopy(read, 0, buffer, count - remaining, read.length);
            remaining -= toRead;
        }
        return buffer;
    }

    /**
     * Add types and properties to the currently processed trace.
     *
     * @param id the id of the trace to enrich
     * @param types the types to add
     * @param properties the properties to add
     * @param tracelets the tracelets to add
     * @param transformations the transformations to add
     */
    public void enrichTrace(final String id, final Set<String> types, final Map<String, Object> properties, final List<Trace.Tracelet> tracelets, final Map<String, List<DataTransformation>> transformations) {
        buffer(Any.pack(Pack.traceEnrichment(id, types, properties, tracelets, transformations)));
    }

    /**
     * Signal that we are in the scope of a newly created child {@link Trace}.
     *
     * @param id the id of the new child
     * @param name the name of the new child
     */
    public void beginChild(final String id, final String name) {
        buffer(Any.pack(RpcBeginChild.newBuilder().setId(id).setName(name).build()));
    }

    /**
     * Signal that we exited the last created child {@link Trace} scope.
     *
     * @param id the id of the closed child
     */
    public void finishChild(final String id) {
        buffer(Any.pack(RpcFinishChild.newBuilder().setId(id).build()));
    }

    /**
     * Signal that we start writing a data stream of given type.
     *
     * @param id the id of the trace to write the data to
     * @param dataType type of the data stream we will be writing to (raw, html...)
     */
    public void beginWritingData(final String id, final String dataType) {
        buffer(Any.pack(RpcBeginDataStream.newBuilder().setTraceId(id).setDataType(dataType).build()));
    }

    /**
     * Send a data chunk of the current data stream to the client. Will send multiple messages if
     * {@code length > _maximumWriteChunkSize}. Multiple calls to this method are regarded as sequential
     * data writes. Every written chunk is followed by a buffered sync action.
     *
     * @param id the id of the trace to write the data to
     * @param dataType the stream data type
     * @param data the chunk to write
     * @param offset the offset in the buffer to write from
     * @param length the amount of data from the buffer to write
     */
    public void writeData(final String id, final String dataType, final byte[] data, final int offset, final int length) {
        int remaining = length;
        while (remaining > 0) {
            // computed as long first so that offset + length cannot silently overflow
            final long currentPosition = (long) offset + length - remaining;
            final int from = toIntExact(currentPosition);
            final int toWrite = min(remaining, _maximumWriteChunkSize);
            buffer(Any.pack(writeDataMessage(id, dataType, data, from, toWrite)));
            buffer(rpcSync()); // do a sync after every write
            resetSync();
            remaining -= toWrite;
        }
    }

    private RpcWriteDataStream writeDataMessage(
        final String id,
        final String dataType,
        final byte[] buffer,
        final int offset,
        final int length) {
        return RpcWriteDataStream.newBuilder()
            .setTraceId(id)
            .setDataType(dataType)
            .setData(ByteString.copyFrom(buffer, offset, length))
            .build();
    }

    /**
     * Signal that we finished writing to the data stream of given type.
     *
     * @param id the id of the trace we were writing data to
     * @param dataType type of the data stream we were writing to (raw, html...)
     */
    public void finishWritingData(final String id, final String dataType) {
        buffer(Any.pack(RpcFinishDataStream.newBuilder().setTraceId(id).setDataType(dataType).build()));
    }

    /**
     * Asks the client to perform a search request in Hansken to retrieve traces.
     *
     * @param query the query to perform
     * @param count maximum traces to return
     * @param scope scope to limit the search to (project or image)
     * @return a SearchResult containing the total number of traces and the traces
     * @throws IllegalArgumentException if the request's count exceeds the limit of {@code SEARCH_REQUEST_COUNT_LIMIT}
     */
    public SearchResult searchTraces(final String query, final int count, final SearchScope scope) {
        // TODO: HANSKEN-15066 support search results of indefinite size
        if (count > SEARCH_REQUEST_COUNT_LIMIT) { // safety check to fit answers into one gRPC message
            throw new IllegalArgumentException("search request count must not exceed the limit of " + SEARCH_REQUEST_COUNT_LIMIT);
        }
        final Any anyResult = call(Pack.searchRequest(query, count, scope));
        final RpcSearchResult rpcResult = Unpack.any(anyResult, RpcSearchResult.class);
        final BatchSearchResult result = new BatchSearchResult(rpcResult.getTotalResults());
        // Convert all retrieved rpctraces to traceproxies
        result.setTraces(
            rpcResult
                .getTracesList()
                .stream()
                .map(trace -> SearchTraceProxy.fromRpc(trace, this))
                .toArray(SearchTraceProxy[]::new)
        );

        return result;
    }

    /**
     * Receive a response for a request and handle it, by handing it to the request waiting for it.
     *
     * @param any the response received
     * @throws IllegalStateException if no response was expected, or if the incoming queue did not accept it
     */
    public void handleResponse(final Any any) {
        if (!_expectingResponse.get()) {
            throw new IllegalStateException("unexpected message received: " + any);
        }
        if (!_incomingMessages.offer(any)) {
            throw new IllegalStateException("could not handle incoming message: " + any);
        }
    }

    /**
     * Signal completion of the stream due to an error.
     *
     * @param statusCode the error code
     * @param t the cause of the error
     */
    public void onError(final Status.Code statusCode, final Throwable t) {
        _outgoingMessages.onError(Pack.asStatusRuntimeException(statusCode, t));
    }

    /**
     * Signal completion of the stream to the outgoing observer, see {@link StreamObserver#onCompleted()}.
     */
    public void onCompleted() {
        _outgoingMessages.onCompleted();
    }

    /**
     * Send a {@link RpcFinish finish} message to the client, containing the actions buffered so far.
     *
     * @param duration The duration of the process()-execution in seconds (to be included in the execution profile)
     */
    public void finishProcessing(final double duration) {
        final RpcFinish update = RpcFinish.newBuilder()
            .setUpdate(RpcBatchUpdate.newBuilder().addAllActions(_bufferedUpdates).build())
            .setProfile(profile(duration))
            .build();
        _outgoingMessages.onNext(Any.pack(update));
    }

    /**
     * This method sends an {@link RpcPartialFinishWithError} message to the client in case an error occurs, containing
     * a partial (or empty) set of actions that have been buffered so far, and a description of the error.
     * The status code is {@code CANCELLED} when nothing is currently buffered, and {@code DATA_LOSS} otherwise.
     *
     * @param t the cause of the error
     * @param duration The duration of the process()-execution in seconds (to be included in the execution profile)
     */
    public void processPartialResultOrError(final Throwable t, final double duration) {
        final Status.Code statusCode = _currentlyBufferedSize == 0 ? Status.Code.CANCELLED : Status.Code.DATA_LOSS;
        _outgoingMessages.onNext(Any.pack(RpcPartialFinishWithError.newBuilder()
            .addAllActions(_bufferedUpdates)
            .setStatusCode(statusCode.name())
            .setErrorDescription(ExceptionUtils.getStackTrace(t))
            .setProfile(profile(duration))
            .build()));
    }

    private RpcProfile profile(final double duration) {
        return RpcProfile.newBuilder()
            .putProfileDoubles("duration", duration)
            .build();
    }

    /**
     * Queue an action for the next batch update. If adding the action would make the batch message exceed the
     * maximum message size, the actions buffered so far are first sent as one {@link RpcBatchUpdate} (blocking
     * for the client's response), followed by a sync when enough unsynced actions have accumulated.
     *
     * @param action the packed action to buffer
     */
    private void buffer(final Any action) {
        final int actionSize = action.getSerializedSize();
        final long estimatedUpdateMessageSize =
            _currentlyBufferedSize + actionSize + PER_ACTION_OVERHEAD_SIZE + MESSAGE_OVERHEAD_SIZE;

        // only flush when there is something to flush: an action that on its own exceeds the estimate would
        // otherwise cause a blocking round trip carrying an empty batch
        if (estimatedUpdateMessageSize > _bufferedMessageSize && !_bufferedUpdates.isEmpty()) {
            final RpcBatchUpdate update = RpcBatchUpdate.newBuilder()
                .addAllActions(_bufferedUpdates)
                .build();
            _bufferedUpdates.clear();
            _currentlyBufferedSize = 0;
            call(Any.pack(update));
            syncIfReady(); // pause for a sync after every batch
        }
        _currentlyBufferedSize += actionSize + PER_ACTION_OVERHEAD_SIZE;
        _bufferedUpdates.add(action);
        _unsyncedRequests++;
    }

    /**
     * Perform a blocking sync round trip if more than {@code MAX_UNSYNCED_REQUESTS} actions were buffered since
     * the last sync.
     *
     * @return the client's response to the sync, or empty if no sync was needed
     */
    private Optional<Any> syncIfReady() {
        if (_unsyncedRequests > MAX_UNSYNCED_REQUESTS) {
            resetSync();
            return Optional.of(call(rpcSync()));
        }
        return Optional.empty();
    }

    private void resetSync() {
        _unsyncedRequests = 0;
    }

    private static Any rpcSync() {
        return Any.pack(RpcSync.getDefaultInstance());
    }

    private Any call(final Message message) {
        return call(Any.pack(message));
    }

    /**
     * Send a request and block until its response arrives through {@link #handleResponse(Any)}.
     *
     * @param message the request to send
     * @return the response to the request
     * @throws IllegalStateException if another request is still awaiting its response, or if interrupted while waiting
     */
    private Any call(final Any message) {
        // Claim the in-flight slot BEFORE entering the try block: if another request is still pending, this call
        // must fail without resetting the flag in the finally block, otherwise the pending request's response
        // would be rejected by handleResponse as unexpected.
        if (_expectingResponse.getAndSet(true)) {
            throw new IllegalStateException("expecting response for previous request, can not send one in parallel");
        }
        try {
            _outgoingMessages.onNext(message);
            return _incomingMessages.take();
        }
        catch (final InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        finally {
            _expectingResponse.set(false);
        }
    }
}
