/*
 * Copyright (C) 2018 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */

package dulab.adap.common.algorithms.machineleanring;

import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import smile.clustering.HierarchicalClustering;
import smile.clustering.linkage.CompleteLinkage;
import smile.clustering.linkage.Linkage;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Groups intervals into clusters with complete-linkage hierarchical clustering.
 *
 * <p>Pairwise dissimilarities come from {@link Interval#distance(Interval)}; the threshold
 * given to {@link #cluster(List, double)} is forwarded to
 * {@code HierarchicalClustering.partition(double)}.
 *
 * @param <T> the concrete interval type being clustered
 */
public class HierarchicalIntervalClusterer<T extends HierarchicalIntervalClusterer.Interval> {

    private static final Logger LOGGER = LogManager.getLogger(HierarchicalIntervalClusterer.class);

    /**
     * This class is used to represent a numerical interval and find a distance between two intervals
     */
    public static abstract class Interval {

        /**
         * @return the start of this interval
         */
        public abstract double getStart();

        /**
         * @return the end of this interval
         */
        public abstract double getEnd();

        /**
         * Computes the distance between this interval and another one as the larger of
         * {@code this.getEnd() - that.getStart()} and {@code that.getEnd() - this.getStart()}.
         *
         * <p>The value is symmetric in its two arguments. Note that the distance from an
         * interval to itself is its own length ({@code getEnd() - getStart()}), not zero.
         *
         * @param that the interval to compare with
         * @return the distance between the two intervals
         */
        public double distance(Interval that) {
            double delta1 = this.getEnd() - that.getStart();
            double delta2 = that.getEnd() - this.getStart();
            return Math.max(delta1, delta2);
        }
    }


    /**
     * Clusters the given intervals.
     *
     * <p>A symmetric proximity matrix is filled from {@link Interval#distance(Interval)},
     * complete-linkage hierarchical clustering is built on it, and the labels returned by
     * {@code HierarchicalClustering.partition(threshold)} are turned into lists of intervals.
     * Clusters are returned in the order in which their labels first appear in the partition,
     * and the intervals inside each cluster keep their relative order from {@code intervals}.
     *
     * @param intervals intervals to cluster; an empty list yields an empty result
     * @param threshold value forwarded to {@code HierarchicalClustering.partition(double)}
     * @return the clusters, each one a non-empty list of intervals
     * @throws IllegalStateException if the partitioning step fails; the underlying exception
     *                               is attached as the cause
     */
    public List<List<T>> cluster(List<T> intervals, double threshold) throws IllegalStateException {

        final int numIntervals = intervals.size();

        // Nothing to cluster: do not hand an empty proximity matrix to the clustering library
        if (numIntervals == 0) {
            return new ArrayList<>();
        }

        // Calculate the distance matrix. Only pairs with i < j are evaluated and mirrored,
        // so the diagonal keeps its default value of zero.
        double[][] proximityMatrix = new double[numIntervals][numIntervals];

        for (int i = 0; i < numIntervals; ++i) {
            for (int j = i + 1; j < numIntervals; ++j) {
                double d = intervals.get(i).distance(intervals.get(j));
                proximityMatrix[i][j] = d;
                proximityMatrix[j][i] = d;
            }
        }

        // Perform clustering
        Linkage linkage = new CompleteLinkage(proximityMatrix);
        HierarchicalClustering clustering = new HierarchicalClustering(linkage);

        int[] partition;
        try {
            partition = clustering.partition(threshold);

        } catch (Exception e) {
            String errorMessage = String.format("Error during clustering of %d intervals", numIntervals);
            LOGGER.info(errorMessage, e);
            // Chain the original exception so callers can see why partitioning failed
            throw new IllegalStateException(errorMessage, e);
        }

        // Form clusters in a single pass over the labels. LinkedHashMap keeps the labels in
        // order of first appearance, which fixes the order of the returned clusters.
        Map<Integer, List<T>> clustersByLabel = new LinkedHashMap<>();
        for (int i = 0; i < numIntervals; ++i) {
            clustersByLabel
                    .computeIfAbsent(partition[i], label -> new ArrayList<>())
                    .add(intervals.get(i));
        }

        return new ArrayList<>(clustersByLabel.values());
    }
}
