package dulab.adap.common.types;

import com.google.common.collect.Range;
import org.apache.commons.lang3.mutable.MutableDouble;

import javax.annotation.Nonnull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/**
 * A static interval tree: a balanced binary search tree keyed on the start of each item's interval.
 * Each node is augmented with the largest interval end found in its subtree. This lets
 * {@link #search(Range)} skip whole subtrees that cannot overlap the query.
 *
 * <p>The tree is built once from an array of items and is not modified afterwards.
 *
 * @author Du-Lab Team <dulab.binf@gmail.com>
 */
public class IntervalTree<T extends IntervalTree.Item>
{
    /**
     * Interface for any object with an interval.
     * The tree reads both endpoints of the interval, so it must be bounded on both sides.
     */
    public interface Item {
        Range<Double> getInterval();
    }

    /**
     * A node of the interval tree. Declared static so it does not hold a hidden reference to the
     * enclosing tree instance.
     */
    private static final class Node<E extends Item>
    {
        /** The item associated with the node */
        final E item;

        /** The start of the item's interval (the BST key) */
        final double leftValue;

        /** The maximum end value across the intervals of this node and all of its descendants */
        double rightValue;

        /** Left branch of the node (items whose start is not greater than {@link #leftValue}) */
        Node<E> left = null;

        /** Right branch of the node (items whose start is not less than {@link #leftValue}) */
        Node<E> right = null;

        /**
         * Creates a node and initialises leftValue / rightValue to the start and end of the
         * item's interval. rightValue is later raised to cover the node's descendants.
         *
         * @param item the item the node is associated with
         */
        Node(@Nonnull E item) {
            this.item = item;
            final Range<Double> interval = item.getInterval();
            this.leftValue = interval.lowerEndpoint();
            this.rightValue = interval.upperEndpoint();
        }
    }

    /** Root of the interval tree ({@code null} when the tree was built from an empty array) */
    private final Node<T> root;

    /**
     * Creates an instance of the interval tree.
     *
     * <p>The given array is not modified; the tree sorts its own copy.
     *
     * @param items array of items the interval tree is built for
     */
    public IntervalTree(@Nonnull T[] items)
    {
        // Sort a copy by interval start: the median split in buildTree relies on this ordering to
        // produce a balanced BST, and copying leaves the caller's array untouched.
        final T[] sorted = items.clone();
        Arrays.sort(sorted, new Comparator<Item>() {
            @Override
            public int compare(Item o1, Item o2) {
                return Double.compare(
                        o1.getInterval().lowerEndpoint(),
                        o2.getInterval().lowerEndpoint());
            }
        });

        root = buildTree(sorted, 0, sorted.length);
    }

    /**
     * Builds the subtree for the sorted slice {@code items[from, to)}:
     *   > Take the median item of the slice as the node
     *   > Recursively build the left subtree from the lower half and the right subtree from the
     *     upper half
     *   > Set the node's rightValue to the maximum of its own interval end and the rightValue of
     *     each existing child
     *
     * @param items array of items sorted by interval start
     * @param from  first index of the slice (inclusive)
     * @param to    last index of the slice (exclusive)
     * @return the root of the subtree, or {@code null} if the slice is empty
     */
    private Node<T> buildTree(@Nonnull T[] items, int from, int to)
    {
        if (from >= to) return null;

        // Same median choice as splitting a sub-array of length (to - from) at its midpoint
        final int middle = from + (to - from) / 2;
        final Node<T> node = new Node<>(items[middle]);

        node.left = buildTree(items, from, middle);
        node.right = buildTree(items, middle + 1, to);

        // Each child is folded in separately so neither subtree's maximum end is lost
        if (node.left != null)
            node.rightValue = Math.max(node.rightValue, node.left.rightValue);
        if (node.right != null)
            node.rightValue = Math.max(node.rightValue, node.right.rightValue);

        return node;
    }

    /**
     * Finds all items in the tree whose intervals overlap the search interval.
     *
     * <p>"Overlap" is decided by {@link Range#isConnected(Range)}. This means intervals that
     * only touch at an endpoint are also reported. The search interval may be unbounded on either
     * side.
     *
     * @param interval interval used for the search
     * @return list of items whose intervals overlap the search interval, in ascending order of
     *         interval start (empty if none)
     */
    @Nonnull
    public List<T> search(@Nonnull Range<Double> interval)
    {
        final List<T> searchResult = new ArrayList<>();

        searchStep(interval, root, searchResult);

        return searchResult;
    }

    /**
     * Recursively collects items whose intervals overlap the search interval. For a given node:
     *   > Skip the node and its whole subtree if the node's rightValue (max end in the subtree)
     *     lies to the left of the search interval
     *   > Search the left subtree, then check the node itself
     *   > Skip the right subtree if the node's leftValue lies to the right of the search interval.
     *     Every item there starts at or after leftValue, so none of them can overlap.
     *   > Otherwise search the right subtree
     *
     * @param searchInterval interval used for the search (may be unbounded on either side)
     * @param node           current node of the interval tree, possibly {@code null}
     * @param searchResult   list the overlapping items are appended to
     */
    private void searchStep(@Nonnull Range<Double> searchInterval, Node<T> node,
                            @Nonnull List<T> searchResult)
    {
        // Don't search a node that doesn't exist
        if (node == null) return;

        // An unbounded lower side can never be to the right of any subtree
        if (searchInterval.hasLowerBound() && node.rightValue < searchInterval.lowerEndpoint())
            return;

        // Search the left subtree
        searchStep(searchInterval, node.left, searchResult);

        // Check this node
        if (searchInterval.isConnected(node.item.getInterval()))
            searchResult.add(node.item);

        // An unbounded upper side can never be to the left of any subtree
        if (searchInterval.hasUpperBound() && node.leftValue > searchInterval.upperEndpoint())
            return;

        // Search the right subtree
        searchStep(searchInterval, node.right, searchResult);
    }
}
