package dulab.adap.common.algorithms.machineleanring;

import dulab.adap.common.types.IntervalTree;
import org.apache.commons.math3.exception.NotPositiveException;

import javax.annotation.Nonnull;
import java.util.*;

/**
 * This class implements a constrained DBSCAN clusterer for items that carry an interval.
 *
 * <p>For each item, we find
 * <ul>
 *   <li>its neighbors, and</li>
 *   <li>its reachable items (item B is reachable from item A if there is a chain of neighboring
 *       items that connects A and B).</li>
 * </ul>
 *
 * <p>An item with at least {@code minItems} neighbors, together with its neighbors and its reachable
 * items, forms a cluster. An item with fewer than {@code minItems} neighbors is marked as noise; it may
 * still be absorbed into a cluster later if it turns out to be a neighbor of a clustered item.
 *
 * <p>Two items are neighbors if their intervals overlap with a positive length and
 * <pre>
 *   min(1 - overlap / length(A), 1 - overlap / length(B)) &lt; eps
 * </pre>
 * where {@code overlap} is the length of the intersection of the two intervals. In particular, when one
 * interval is completely contained in the other, this quantity is 0, so the two items are neighbors
 * for any {@code eps > 0}.
 *
 * @author Du-Lab Team <dulab.binf@gmail.com>
 */
public class PointIntervalClusterer<T extends PointIntervalClusterer.Item>
{
    /**
     * Interface for items to be clustered. Implementations provide getValue() and getInterval().
     * The current neighbor criterion uses only the interval; getValue() is not read by this class.
     */
    public interface Item extends IntervalTree.Item {
        double getValue();
    }

    /**
     * Neighborhood threshold: two overlapping items are neighbors when the smaller of their
     * relative non-overlap fractions is strictly less than this value.
     */
    private final double eps;

    /** Minimum number of neighbors an item needs in order to start or extend a cluster */
    private final int minItems;

    /** Status of an item during the clustering process; an item without a status has not been visited */
    private enum PointStatus {
        NOISE, // The item is considered to be noise
        PART_OF_CLUSTER // The item is already part of a cluster
    }

    /**
     * Creates a new instance of a PointIntervalClusterer
     * @param eps neighborhood threshold on the relative non-overlap of two intervals
     * @param minItems minimum number of neighbors needed to form a cluster
     * @throws NotPositiveException if {@code eps} or {@code minItems} is negative
     */
    public PointIntervalClusterer(final double eps, final int minItems)
            throws NotPositiveException
    {
        if (eps < 0.0) throw new NotPositiveException(eps);
        if (minItems < 0) throw new NotPositiveException(minItems);

        this.eps = eps;
        this.minItems = minItems;
    }

    /**
     * Clusters items together. Items that end up as noise are not part of any returned cluster.
     * @param items items to be clustered
     * @return list of clusters, in the order in which their first item appears in {@code items}
     */
    @Nonnull
    public List<List<T>> cluster(@Nonnull final T[] items)
    {
        final IntervalTree<T> intervalTree = new IntervalTree<>(items);

        final List<List<T>> clusters = new ArrayList<>();

        // NOTE(review): statuses are keyed by equals()/hashCode(), whereas seeds are de-duplicated by
        // identity in expandCluster(); the two agree as long as T does not override equals() - confirm.
        final Map<T, PointStatus> visited = new HashMap<>();

        for (final T item : items)
        {
            // Skip items that were already classified as noise or assigned to a cluster
            if (visited.get(item) != null) continue;

            final List<T> neighbors = getNeighbors(item, intervalTree);

            if (neighbors.size() < minItems)
                visited.put(item, PointStatus.NOISE);
            else {
                final List<T> cluster = new ArrayList<>();
                expandCluster(item, neighbors, cluster, intervalTree, visited);
                clusters.add(cluster);
            }
        }
        return clusters;
    }

    /**
     * Returns a list of neighbors of the item among the tree entries found by searching its interval.
     * A candidate is a neighbor if its interval overlaps the item's interval with a positive length and
     * the smaller of the two relative non-overlap fractions, {@code 1 - overlap / length}, is strictly
     * less than eps.
     *
     * NOTE(review): if the tree search returns the queried item itself, it is counted as its own
     * neighbor (its distance to itself is 0) - confirm this is intended for the minItems threshold.
     *
     * @param item the item to look for
     * @param tree possible neighbors
     * @return the list of neighbors
     */
    @Nonnull
    private List<T> getNeighbors(final T item, @Nonnull final IntervalTree<T> tree)
    {
        final List<T> neighbors = new ArrayList<>();

        for (final T neighbor : tree.search(item.getInterval()))
        {
            final double itemStart = item.getInterval().lowerEndpoint();
            final double itemEnd = item.getInterval().upperEndpoint();
            final double neighborStart = neighbor.getInterval().lowerEndpoint();
            final double neighborEnd = neighbor.getInterval().upperEndpoint();

            final double overlapStart = Math.max(itemStart, neighborStart);
            final double overlapEnd = Math.min(itemEnd, neighborEnd);

            // A strictly positive overlap guarantees that both interval lengths below are positive
            if (overlapStart < overlapEnd) {
                final double overlap = overlapEnd - overlapStart;
                final double distance = Math.min(
                        1.0 - overlap / (itemEnd - itemStart),
                        1.0 - overlap / (neighborEnd - neighborStart));

                if (distance < eps)
                    neighbors.add(neighbor);
            }
        }
        return neighbors;
    }

    /**
     * Adds the item, its neighbors, and reachable items to the cluster
     * @param item item to be added to the cluster
     * @param neighbors neighbors of the item
     * @param cluster the cluster
     * @param tree interval tree containing all the items
     * @param visited map with items statuses (null, PART_OF_CLUSTER, NOISE)
     */
    private void expandCluster(@Nonnull T item,
                               @Nonnull List<T> neighbors,
                               @Nonnull List<T> cluster,
                               @Nonnull IntervalTree<T> tree,
                               @Nonnull Map<T, PointStatus> visited)
    {
        cluster.add(item);
        visited.put(item, PointStatus.PART_OF_CLUSTER);

        // Work list of items to examine; it grows while it is being traversed
        final List<T> seeds = new ArrayList<>(neighbors);

        // Identity-based membership index of the work list, so that appending new seeds
        // does not require rescanning the whole list
        final Set<T> knownSeeds = Collections.newSetFromMap(new IdentityHashMap<T, Boolean>());
        knownSeeds.addAll(seeds);

        for (int i = 0; i < seeds.size(); ++i)
        {
            final T current = seeds.get(i);
            final PointStatus status = visited.get(current);

            // Only items that have not been visited yet can contribute new seeds
            if (status == null) {
                final List<T> currentNeighbors = getNeighbors(current, tree);
                if (currentNeighbors.size() >= minItems)
                    merge(seeds, knownSeeds, currentNeighbors);
            }

            // Unvisited items and items previously marked as noise join the cluster
            if (status != PointStatus.PART_OF_CLUSTER) {
                cluster.add(current);
                visited.put(current, PointStatus.PART_OF_CLUSTER);
            }
        }
    }

    /**
     * Appends to the seed list every candidate that is not yet in it (compared by identity),
     * preserving the order of the candidates.
     * @param seeds seed list to be extended in place
     * @param knownSeeds identity-based set mirroring the contents of {@code seeds}; updated in place
     * @param candidates items to be appended if they are not seeds yet
     */
    private void merge(@Nonnull final List<T> seeds,
                       @Nonnull final Set<T> knownSeeds,
                       @Nonnull final List<T> candidates)
    {
        for (final T candidate : candidates)
            if (knownSeeds.add(candidate))
                seeds.add(candidate);
    }
}
