package org.hansken.plugin.extraction.runtime.grpc.server.proxy;

import static java.lang.Math.toIntExact;

import static org.hansken.plugin.extraction.api.TraceSearcher.ALL_SEARCH_RESULTS;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.Spliterators;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.hansken.extraction.plugin.grpc.RpcSearchResult;
import org.hansken.extraction.plugin.grpc.RpcSearchTrace;
import org.hansken.plugin.extraction.api.SearchResult;
import org.hansken.plugin.extraction.api.SearchScope;
import org.hansken.plugin.extraction.api.SearchSortOption;
import org.hansken.plugin.extraction.api.SearchTrace;
import org.hansken.plugin.extraction.runtime.grpc.common.Pack;
import org.hansken.plugin.extraction.runtime.grpc.common.Unpack;

import com.google.protobuf.Any;

/**
 * A search result that queries the traces in batches, so all results can be retrieved.
 * <p>
 * The first batch is requested when the result is constructed, because the response of that request is the
 * only source for the total number of hits. Later batches are requested lazily while the stream returned by
 * {@link #getTraces()} is being consumed.
 *
 * @implNote Limited to retrieving 100,000 traces due to Elasticsearch limitations.
 * @author Netherlands Forensic Institute
 */
class BatchedSearchResult implements SearchResult {
    // Marker value of _totalResults until the first search response has been received.
    private static final int NOT_KNOWN_YET = -1;
    // Considering max gRPC msg size of 64Mb.
    private static final int MAX_BATCH_SIZE = 50;
    // Safety check: Elasticsearch can't handle indexes bigger than 100k.
    private static final int ELASTICSEARCH_INDEX_LIMIT = 100_000;

    private final GrpcFacade _grpcFacade;
    private final String _query;
    // Maximum number of traces to return, or ALL_SEARCH_RESULTS for no limit.
    private final int _count;
    private final SearchScope _scope;
    // Offset of the first trace to return.
    private final int _start;
    private final List<SearchSortOption> _sort;

    private BatchedSearchTraceSpliterator _spliterator;
    private long _totalResults;

    /**
     * Creates the search result and immediately requests the first batch of traces, so that the total number
     * of hits is known as soon as the constructor returns.
     *
     * @param grpcFacade facade used to send the search requests
     * @param query the query to search with
     * @param count maximum number of traces to return, or {@code ALL_SEARCH_RESULTS} to return all of them
     * @param scope the scope to search in
     * @param start offset of the first trace to return
     * @param sort the sort options that are passed along with every search request
     */
    BatchedSearchResult(final GrpcFacade grpcFacade, final String query, final int count, final SearchScope scope, final int start, final List<SearchSortOption> sort) {
        _grpcFacade = grpcFacade;
        _query = query;
        _count = count;
        _scope = scope;
        _start = start;
        _sort = sort;

        // Initialize the total results by retrieving the first batch of traces.
        _totalResults = NOT_KNOWN_YET;
        _spliterator = new BatchedSearchTraceSpliterator(_start);
        _spliterator.initializeTotalResults();
    }

    /**
     * Returns a sequential stream over the traces of this search result.
     * <p>
     * A spliterator that has not handed out any trace yet is reused, so the batch that was prefetched by the
     * constructor is not requested a second time. Once a spliterator has handed out at least one trace, the
     * next call starts over from the start offset with a fresh spliterator.
     *
     * @return a stream of the found traces
     */
    @Override
    public Stream<SearchTrace> getTraces() {
        // Replace a "partially consumed" spliterator with a new instance.
        if (_spliterator.hasAdvanced()) {
            _spliterator = new BatchedSearchTraceSpliterator(_start);
        }
        return StreamSupport.stream(_spliterator, false);
    }

    /**
     * Returns the total number of results as reported by the response of the first search request.
     *
     * @return the total number of hits
     */
    @Override
    public long getTotalHits() {
        return _totalResults;
    }

    /**
     * Spliterator that hands out the traces one by one and requests the next batch whenever the current
     * batch is exhausted, the requested count has not been reached and the total number of results has not
     * been passed.
     */
    private class BatchedSearchTraceSpliterator extends Spliterators.AbstractSpliterator<SearchTrace> {
        // Offset of the next trace to hand out; also used as the offset of the next search request.
        private int _position;
        // Number of traces handed out by this spliterator so far.
        private int _currentCount;
        // Traces of the most recently requested batch that have not been handed out yet.
        private Deque<RpcSearchTrace> _currentBatch;

        BatchedSearchTraceSpliterator(final int start) {
            // The size is reported as unknown (Long.MAX_VALUE), traces are only retrieved on demand.
            super(Long.MAX_VALUE, ORDERED | DISTINCT | IMMUTABLE | NONNULL);
            _position = start;
            _currentCount = 0;
            _currentBatch = new ArrayDeque<>();
        }

        /**
         * Hands the next trace to the given action, requesting a new batch first when needed.
         *
         * @param action consumer of the next trace
         * @return {@code true} if a trace was handed to the action, {@code false} if there are no more traces
         * @throws IllegalStateException if the current offset exceeds the Elasticsearch index limit
         */
        @Override
        public boolean tryAdvance(final Consumer<? super SearchTrace> action) {
            if (_position > ELASTICSEARCH_INDEX_LIMIT) {
                throw new IllegalStateException("Search request offset must not exceed the Elasticsearch limit of 100.000.");
            }

            final boolean countNotReached = _count == ALL_SEARCH_RESULTS || _currentCount < _count;
            final boolean hasRemainingTraces = _position < _totalResults;
            if (countNotReached && _currentBatch.isEmpty() && hasRemainingTraces) {
                prepareNextBatch();
            }
            if (!_currentBatch.isEmpty()) {
                action.accept(SearchTraceProxy.fromRpc(_currentBatch.pop(), _grpcFacade));
                _position++;
                _currentCount++;
                return true;
            }
            return false;
        }

        /**
         * Requests the next batch of traces, starting at the current position, and stores it as the current
         * batch. The total number of results of the outer class is set from the first response only.
         */
        private void prepareNextBatch() {
            final Any anyResult = _grpcFacade.call(Pack.searchRequest(_query, getNumberOfResultsToGet(), _scope, _position, _sort));
            final RpcSearchResult rpcResult = Unpack.any(anyResult, RpcSearchResult.class);

            if (_totalResults == NOT_KNOWN_YET) {
                _totalResults = rpcResult.getTotalResults();
            }
            _currentBatch = new ArrayDeque<>(rpcResult.getTracesList());
        }

        /**
         * Determines the size of the next batch. As long as the total number of results is unknown only the
         * requested count can be taken into account.
         */
        private int getNumberOfResultsToGet() {
            if (_totalResults == NOT_KNOWN_YET) {
                return getFirstBatchSize();
            }

            return getSubsequentBatchSize();
        }

        /**
         * Size of the very first batch: the requested count if that is smaller than the maximum batch size,
         * the maximum batch size otherwise.
         */
        private int getFirstBatchSize() {
            if (_count != ALL_SEARCH_RESULTS && _count < MAX_BATCH_SIZE) {
                return _count;
            }
            return MAX_BATCH_SIZE;
        }

        /**
         * Size of a batch once the total number of results is known: the number of traces that still has to
         * be retrieved, capped at the maximum batch size.
         */
        private int getSubsequentBatchSize() {
            final long remaining = getRemaining();
            if (remaining > MAX_BATCH_SIZE) {
                return MAX_BATCH_SIZE;
            }
            // Safe to cast to an integer, because this is at most MAX_BATCH_SIZE.
            return toIntExact(remaining);
        }

        /**
         * Number of traces this spliterator still has to retrieve.
         * <p>
         * The total number of results is compared against the current position (which includes the start
         * offset), exactly like the check in {@link #tryAdvance(Consumer)}. Comparing it against the number
         * of traces handed out so far would ignore the start offset and make the last request ask for more
         * traces than are left. A requested count additionally limits the result, but never raises it above
         * the number of traces that are left past the current position.
         */
        private long getRemaining() {
            final long remainingPastPosition = _totalResults - _position;
            if (_count == ALL_SEARCH_RESULTS) {
                return remainingPastPosition;
            }
            final long remainingOfCount = (long) _count - _currentCount;
            return Math.min(remainingOfCount, remainingPastPosition);
        }

        /**
         * Makes sure the _totalResults of the outer class is filled by requesting the first batch, because
         * that is the only way to determine the total number of results. The retrieved batch is kept, so the
         * first traces handed out by this spliterator do not need another request.
         */
        private void initializeTotalResults() {
            prepareNextBatch();
        }

        /**
         * Check if the spliterator has advanced, i.e. the action of {@link #tryAdvance(Consumer)} has consumed at
         * least one trace.
         *
         * @return <code>true</code> if the spliterator has advanced, <code>false</code> otherwise
         */
        public boolean hasAdvanced() {
            return _currentCount > 0;
        }
    }
}
