package org.kink_lang.kink.internal.ovis;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Collection;
import java.util.List;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

import javax.annotation.Nullable;

import org.kink_lang.kink.internal.control.Control;
import org.kink_lang.kink.internal.sym.SymRegistryImpl;

/**
 * Mapping from syms to indexes of own vars.
 *
 * <p>Sym handles are distributed to buckets by
 * {@code Integer.remainderUnsigned(symHandle, bucketCount)}.
 * Each instance also keeps a transition table from a sym handle to the mapping
 * produced by {@link #with(int)}, so that once an extension has been recorded,
 * later calls of {@code with} for the same sym handle return the same instance.</p>
 */
public class OwnVarIndexes {

    /**
     * Pairs (high = from index, low = to index) into {@link #indexHandlePairs},
     * one per bucket. A bucket without entries holds 0, that is the empty range [0, 0).
     */
    private final long[] rangePairs;

    /**
     * Each element is (high=index, low=symHandle).
     * Sorted by hash(symHandle), so that the entries of one bucket are contiguous.
     */
    private final long[] indexHandlePairs;

    /**
     * Transition table from an added sym handle to the resulting mapping.
     * After construction, it is replaced only by compare-and-set through {@link #TT_VH}.
     */
    @SuppressWarnings("PMD.ImmutableField") // written by TT_VH
    private TransitionTable transitionTable;

    /** Whether it has a sym of preloaded funs. */
    private final boolean containsPreloadedVar;

    /** The var handle for transitionTable. */
    private static final VarHandle TT_VH = Control.runWrappingThrowable(
            () -> MethodHandles.lookup()
            .findVarHandle(OwnVarIndexes.class, "transitionTable", TransitionTable.class));

    /**
     * Constructs a mapping.
     *
     * <p>The transition table initially contains an edge from each own sym handle
     * to this instance, because adding an already contained sym yields the same mapping.</p>
     */
    private OwnVarIndexes(
            long[] rangePairs,
            long[] indexHandlePairs,
            boolean containsPreloadedVar) {
        this.rangePairs = rangePairs;
        this.indexHandlePairs = indexHandlePairs;
        this.transitionTable = new TransitionTable(new int[0], new OwnVarIndexes[0]);
        this.containsPreloadedVar = containsPreloadedVar;
        setupTransitionForOwnSyms();
    }

    /**
     * Builds a mapping.
     *
     * @param unsortedIndexHandlePairs pairs of (high=index, low=symHandle) in any order;
     *     sym handles are expected to be distinct.
     * @param bucketCount the number of hash buckets; values less than 1 are treated as 1.
     * @return the mapping.
     */
    static OwnVarIndexes build(Collection<Long> unsortedIndexHandlePairs, int bucketCount) {
        int buckets = Math.max(1, bucketCount);
        long[] indexHandlePairs = sort(unsortedIndexHandlePairs, buckets);
        long[] rangePairs = makeRangePairs(indexHandlePairs, buckets);
        boolean containsPreloadedVar = containsPreloadedVar(unsortedIndexHandlePairs);
        return new OwnVarIndexes(rangePairs, indexHandlePairs, containsPreloadedVar);
    }

    /**
     * Sorts index-handle pairs by the bucket of each sym handle.
     *
     * <p>The sort is stable, so pairs in the same bucket keep their original order.</p>
     */
    private static long[] sort(Collection<Long> unsortedIndexHandlePairs, int bucketCount) {
        Comparator<Long> byBucket = Comparator.comparingInt(
                pair -> hash(getLow(pair), bucketCount));
        return unsortedIndexHandlePairs.stream()
            .sorted(byBucket)
            .mapToLong(Long::longValue)
            .toArray();
    }

    /**
     * Makes from-to ranges for each bucket.
     *
     * <p>Requires indexHandlePairs to be sorted by bucket.
     * Buckets without entries are left as 0, the empty range [0, 0).</p>
     */
    private static long[] makeRangePairs(long[] indexHandlePairs, int bucketCount) {
        long[] rangePairs = new long[bucketCount];
        int prevFrom = -1;
        int prevHash = -1;
        for (int i = 0; i < indexHandlePairs.length; ++ i) {
            int symHandle = getLow(indexHandlePairs[i]);
            int hash = hash(symHandle, bucketCount);
            // Entries of one bucket are contiguous, so a bucket's range starts at its first entry.
            int from = hash == prevHash ? prevFrom : i;
            int to = i + 1;
            rangePairs[hash] = makePair(from, to);
            prevFrom = from;
            prevHash = hash;
        }
        return rangePairs;
    }

    /**
     * Makes a pair of ints.
     *
     * <p>The low int is zero-extended, so that a negative low value
     * does not overwrite the high int with sign bits.</p>
     */
    private static long makePair(int high, int low) {
        return ((long) high) << 32 | Integer.toUnsignedLong(low);
    }

    /**
     * Returns the higher int.
     */
    private static int getHigh(long pair) {
        return (int) (pair >>> 32);
    }

    /**
     * Returns the lower int.
     */
    private static int getLow(long pair) {
        return (int) pair;
    }

    /**
     * Adds an edge from each own sym handle to this instance to the transition table.
     */
    private void setupTransitionForOwnSyms() {
        TransitionTable origTt = this.transitionTable;
        TransitionTable tt = origTt;
        for (long pair : this.indexHandlePairs) {
            int symHandle = getLow(pair);
            tt = tt.plus(symHandle, this);
        }
        setTransitionTable(origTt, tt);
    }

    /**
     * Builds an empty set.
     *
     * @return an empty set.
     */
    public static OwnVarIndexes buildEmpty() {
        return build(List.of(), 1);
    }

    /**
     * Returns the set adding symHandle.
     *
     * <p>If symHandle is already contained, this instance is returned.
     * Otherwise the new sym handle gets the index {@code size()} of this instance.
     * The result is recorded in the transition table by compare-and-set;
     * when another thread records a result first, that result is returned.</p>
     *
     * @param symHandle sym handle added to the ovis.
     * @return the set adding symHandle.
     */
    public OwnVarIndexes with(int symHandle) {
        @Nullable OwnVarIndexes candidate = null;
        while (true) {
            TransitionTable tt = this.transitionTable;
            @Nullable OwnVarIndexes found = tt.oviWith(symHandle);
            if (found != null) {
                return found;
            }

            // Build the extended mapping at most once, even if the CAS has to be retried.
            if (candidate == null) {
                List<Long> newIndexHandlePairs = LongStream
                    .concat(LongStream.of(this.indexHandlePairs),
                            LongStream.of(makePair(size(), symHandle)))
                    .boxed()
                    .toList();
                candidate = build(newIndexHandlePairs, size() + 1);
            }
            setTransitionTable(tt, tt.plus(symHandle, candidate));
        }
    }

    /**
     * Compare-and-sets new transition table.
     *
     * <p>A failed CAS is silently ignored; callers re-read the table if they need the result.</p>
     */
    private void setTransitionTable(TransitionTable origTt, TransitionTable newTt) {
        TT_VH.compareAndSet(this, origTt, newTt);
    }

    /**
     * Returns the index of the sym handle.
     *
     * @param symHandle the sym handle.
     * @return the index of the sym handle, or -1 if absent.
     */
    public int getIndex(int symHandle) {
        int hash = hash(symHandle);
        long range = this.rangePairs[hash];
        int from = getHigh(range);
        int to = getLow(range);
        for (int i = from; i < to; ++ i) {
            long indexHandlePair = this.indexHandlePairs[i];
            if (symHandle == getLow(indexHandlePair)) {
                return getHigh(indexHandlePair);
            }
        }
        return -1;
    }

    /**
     * Hash code of symHandle, using the bucket count of this mapping.
     */
    private int hash(int symHandle) {
        return hash(symHandle, this.rangePairs.length);
    }

    /**
     * Hash code of symHandle with the given bucket count.
     *
     * <p>The handle is treated as unsigned, so the result is always in [0, bucketCount).</p>
     */
    private static int hash(int symHandle, int bucketCount) {
        return Integer.remainderUnsigned(symHandle, bucketCount);
    }

    /**
     * Whether this does not have any sym handle.
     *
     * @return whether this does not have any sym handle.
     */
    public boolean isEmpty() {
        return size() == 0;
    }

    /**
     * Number of sym handles in this mapping.
     */
    private int size() {
        return this.indexHandlePairs.length;
    }

    /**
     * Returns the sym handles, ordered by their indexes.
     *
     * <p>Assumes the indexes are exactly 0 .. size()-1, as assigned by {@link #with(int)}.</p>
     *
     * @return the sym handles ordered by their indexes.
     */
    public List<Integer> getSymHandles() {
        int[] symHandles = new int[this.indexHandlePairs.length];
        for (long pair : this.indexHandlePairs) {
            symHandles[getHigh(pair)] = getLow(pair);
        }
        return IntStream.of(symHandles).boxed().toList();
    }

    /**
     * Returns whether it has the sym of a preloaded var.
     *
     * @return whether it has the sym of a preloaded var.
     */
    public boolean containsPreloadedVar() {
        return this.containsPreloadedVar;
    }

    /**
     * Returns whether indexHandlePairs has the sym of a preloaded var.
     */
    private static boolean containsPreloadedVar(Collection<Long> indexHandlePairs) {
        for (long pair : indexHandlePairs) {
            int symHandle = getLow(pair);
            if (SymRegistryImpl.isPreloaded(symHandle)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Immutable transition table.
     *
     * <p>Sym handles are kept sorted, and each one maps to the destination at the same index.</p>
     */
    private static class TransitionTable {

        /** Sym handles, sorted ascending for binary search. */
        private final int[] symHandles;

        /** Result ovs for the sym handle with the same index. */
        private final OwnVarIndexes[] destinations;

        /**
         * Constructs a table.
         */
        TransitionTable(int[] symHandles, OwnVarIndexes[] destinations) {
            this.symHandles = symHandles;
            this.destinations = destinations;
        }

        /**
         * Finds OwnVarIndexes with the given edge.
         *
         * @return the destination, or null if there is no edge for symHandle.
         */
        @Nullable
        OwnVarIndexes oviWith(int symHandle) {
            int ind = Arrays.binarySearch(this.symHandles, symHandle);
            return ind < 0
                ? null
                : this.destinations[ind];
        }

        /**
         * Makes a new transition table with the specified edge+node.
         *
         * <p>Requires that newSymHandle is not yet in this table;
         * this table itself is not modified.</p>
         */
        final TransitionTable plus(int newSymHandle, OwnVarIndexes newDestination) {
            int origLength = this.symHandles.length;
            // binarySearch returns (-(insertion point) - 1) for an absent key.
            int insertionPoint = ~ Arrays.binarySearch(this.symHandles, newSymHandle);

            int[] newSymHandles = new int[origLength + 1];
            System.arraycopy(this.symHandles, 0, newSymHandles, 0, insertionPoint);
            newSymHandles[insertionPoint] = newSymHandle;
            System.arraycopy(
                    this.symHandles, insertionPoint,
                    newSymHandles, insertionPoint + 1,
                    origLength - insertionPoint);

            OwnVarIndexes[] newDestinations = new OwnVarIndexes[origLength + 1];
            System.arraycopy(this.destinations, 0, newDestinations, 0, insertionPoint);
            newDestinations[insertionPoint] = newDestination;
            System.arraycopy(
                    this.destinations, insertionPoint,
                    newDestinations, insertionPoint + 1,
                    origLength - insertionPoint);
            return new TransitionTable(newSymHandles, newDestinations);
        }

    }

}

// vim: et sw=4 sts=4 fdm=marker
