/*
 * Copyright (C) 2017 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */

package dulab.adap.workflow.decomposition;

import com.google.common.collect.Range;
import dulab.adap.common.algorithms.machineleanring.HierarchicalIntervalClusterer;
import dulab.adap.datamodel.BetterPeak;
import dulab.adap.datamodel.Chromatogram;
import org.dulab.jsparcehc.*;

import javax.annotation.Nullable;
import java.util.*;
import java.util.logging.Logger;
import java.util.stream.Collectors;

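/**
 * Clusters peaks by retention time. Peaks are first split wherever there is a retention-time gap that contains no
 * peaks; any resulting cluster wider than the maximum cluster width is then re-clustered by sparse hierarchical
 * clustering with complete linkage, and finally the boundaries between the clusters are adjusted.
 * <p>
 * NOTE(review): this comment previously described a constrained clustering based on
 * Distance(d1, d2) = 1 - |d1 &cap; d2| / min(|d1|, |d2|) and the preferable cluster diameter. The distance computed
 * in this file is max(end1 - start2, end2 - start1) over the retention-time ranges of two peaks; confirm which
 * definition is intended.
 *
 * @author Du-Lab Team dulab.binf@gmail.com
 */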
public class RetTimeClusterer {

    /** Logger used to report how long the clustering took. */
    private static final Logger LOGGER = Logger.getLogger(RetTimeClusterer.class.getName());

    /**
     * Size threshold for clusters. The checks in this class use a strict comparison
     * ({@code size > MIN_CLUSTER_SIZE}), so a cluster is kept only if it has more peaks than this value.
     */
    private static final int MIN_CLUSTER_SIZE = 4;

    /**
     * Width threshold, in retention-time units. Preliminary clusters whose span exceeds this value are
     * re-clustered; the same value is passed as the threshold to the distance matrix and to the clusterer.
     */
    private final double maxClusterWidth;

    /**
     * Creates an instance of {@link RetTimeClusterer}.
     *
     * @param maxClusterWidth maximum allowed width of clusters; the value is stored as given, without validation
     */
    public RetTimeClusterer(double maxClusterWidth) {
        this.maxClusterWidth = maxClusterWidth;
    }

    /**
     * Clusters peaks by their retention-time ranges.
     * <p>
     * The procedure has three steps:
     * <ol>
     * <li>peaks are pre-clustered by splitting them at retention-time gaps;</li>
     * <li>every preliminary cluster whose span ({@code end - start}) exceeds {@code maxClusterWidth} is
     * re-clustered by the sparse hierarchical clustering with complete linkage, while narrower clusters are
     * kept unchanged;</li>
     * <li>the boundaries of the resulting clusters are adjusted and the peaks are re-assigned accordingly.</li>
     * </ol>
     * The elapsed time is written to the log at INFO level.
     *
     * @param peaks peaks to be clustered
     * @return list of {@link Cluster} containing the formed clusters
     */
    public List<Cluster> execute(List<BetterPeak> peaks) {

        final long startedAt = System.currentTimeMillis();

        // Step 1: split the peaks at the gaps between them
        List<Cluster> gapSeparated = preCluster(peaks);

        // Step 2: only the clusters that are too wide go through the hierarchical clustering
        List<Cluster> refined = new ArrayList<>(gapSeparated.size());
        for (Cluster candidate : gapSeparated) {
            boolean tooWide = candidate.end - candidate.start > maxClusterWidth;
            if (tooWide)
                refined.addAll(cluster(candidate.peaks));
            else
                refined.add(candidate);
        }

        // Step 3: adjust boundaries of all clusters
        List<Cluster> adjusted = adjustClusterBoundaries(refined);

        double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
        LOGGER.info(String.format("Retention time clustering is completed in %.3f sec.", elapsedSeconds));

        return adjusted;
    }

    /**
     * Cluster peaks based on gaps between peaks: if there is a retention time gap so that there are no peaks in it,
     * then the peaks on the left and right of the gap are split into separate clusters.
     * <p>
     * The peaks are processed in the order of their first retention time. The sorting is performed on a private
     * copy, so the list passed by the caller is neither reordered nor required to be modifiable.
     *
     * @param peaks list of peaks; it is not modified
     * @return list of clusters, ordered by the first retention time of their peaks; empty if {@code peaks} is empty
     */
    private List<Cluster> preCluster(List<BetterPeak> peaks) {

        List<Cluster> clusters = new ArrayList<>();

        // Sort a copy of the peaks by their starting retention time, leaving the caller's list untouched
        List<BetterPeak> sortedPeaks = new ArrayList<>(peaks);
        sortedPeaks.sort(Comparator.comparingDouble(BetterPeak::getFirstRetTime));

        // Find gaps between peaks and create clusters
        double currentMaxRetTime = -Double.MAX_VALUE;
        List<BetterPeak> currentPeaks = new ArrayList<>();
        for (BetterPeak peak : sortedPeaks) {

            double firstRetTime = peak.getFirstRetTime();
            double lastRetTime = peak.getLastRetTime();

            // A peak that starts strictly after everything seen so far has ended opens a new cluster.
            // A peak that starts exactly where the current cluster ends stays in that cluster.
            if (firstRetTime > currentMaxRetTime && !currentPeaks.isEmpty()) {
                clusters.add(new Cluster(currentPeaks));
                currentPeaks = new ArrayList<>();
            }

            // Track the right-most retention time covered so far
            if (lastRetTime > currentMaxRetTime)
                currentMaxRetTime = lastRetTime;

            currentPeaks.add(peak);
        }

        // Flush the last group of peaks
        if (!currentPeaks.isEmpty())
            clusters.add(new Cluster(currentPeaks));

        return clusters;
    }

    /**
     * Cluster peaks using the sparse hierarchical clustering with the complete linkage.
     * <p>
     * A sparse distance matrix is built with {@code maxClusterWidth} as the threshold; if it contains no elements,
     * an empty list is returned. Otherwise the peaks are grouped by the labels returned by the clusterer, and only
     * the groups with more than {@code MIN_CLUSTER_SIZE} peaks are turned into clusters.
     *
     * @param peaks list of peaks
     * @return list of clusters
     */
    private List<Cluster> cluster(List<BetterPeak> peaks) {

        DistanceMatrix distances = new DistanceMatrix(peaks, maxClusterWidth);
        if (distances.elements.isEmpty())
            return new ArrayList<>(0);

        SparseHierarchicalClustererV2 clusterer =
                new SparseHierarchicalClustererV2(distances, new CompleteLinkage());
        clusterer.cluster((float) maxClusterWidth);

        // Keys are used as indices into the list of peaks, values are the cluster labels
        Map<Integer, Integer> labelByPeakIndex = clusterer.getLabels();

        // Group the peaks by label, keeping the labels in the order in which they are first encountered
        Map<Integer, List<BetterPeak>> peaksByLabel = new LinkedHashMap<>();
        for (Map.Entry<Integer, Integer> entry : labelByPeakIndex.entrySet()) {
            int label = entry.getValue();
            BetterPeak peak = peaks.get(entry.getKey());
            peaksByLabel.computeIfAbsent(label, key -> new ArrayList<>()).add(peak);
        }

        List<Cluster> clusters = new ArrayList<>(peaksByLabel.size());
        for (List<BetterPeak> members : peaksByLabel.values()) {
            // Strict comparison: a group must have more than MIN_CLUSTER_SIZE peaks to be kept
            if (members.size() > MIN_CLUSTER_SIZE)
                clusters.add(new Cluster(members));
        }

        return clusters;
    }

    /**
     * Adjusts the boundaries between adjacent clusters and re-assigns the peaks to the adjusted clusters.
     * <p>
     * The clusters are sorted in place by their average retention time. A total-intensity chromatogram is built
     * by summing, per retention time, the intensities of the chromatograms of all peaks. For every pair of
     * adjacent clusters that overlap, the end of the first and the start of the second are moved to the boundary
     * returned by {@code findBoundary}. Finally, each cluster is rebuilt from all peaks whose retention time lies
     * strictly between its start and end.
     *
     * @param clusters clusters to adjust; the list is re-ordered and the boundaries of its elements are modified
     * @return adjusted clusters with more than {@code MIN_CLUSTER_SIZE} peaks, without duplicates
     */
    private List<Cluster> adjustClusterBoundaries(List<Cluster> clusters) {

        // Order the clusters by their average retention time
        clusters.sort(Comparator.comparingDouble(c -> c.retTime));

        // Total intensity per retention time, accumulated over the chromatograms of all peaks
        Map<Double, Double> tic = new HashMap<>();
        for (Cluster cluster : clusters) {
            for (BetterPeak peak : cluster.peaks) {
                Chromatogram profile = peak.chromatogram;
                for (int i = 0; i < profile.length; ++i) {
                    double retTime = profile.getRetTime(i);
                    double intensity = profile.getIntensity(i);
                    double accumulated = tic.getOrDefault(retTime, 0.0);
                    tic.put(retTime, intensity + accumulated);
                }
            }
        }

        // Move the shared boundary of every overlapping pair of neighbours
        for (int i = 0; i + 1 < clusters.size(); ++i) {
            Cluster left = clusters.get(i);
            Cluster right = clusters.get(i + 1);

            Double boundary = findBoundary(left, right, tic);
            if (boundary == null)
                continue;

            left.end = boundary;
            right.start = boundary;
        }

        // The pool of peaks does not change below, so it is collected only once
        List<BetterPeak> allPeaks = new ArrayList<>();
        for (Cluster cluster : clusters)
            allPeaks.addAll(cluster.peaks);

        // Rebuild every cluster from the peaks that fall inside its adjusted boundaries
        List<Cluster> adjustedClusters = new ArrayList<>(clusters.size());
        for (Cluster cluster : clusters) {

            // NOTE(review): both comparisons are strict, so a peak whose retention time equals a boundary
            // is not assigned to any cluster - confirm that this is intended
            List<BetterPeak> members = new ArrayList<>();
            for (BetterPeak peak : allPeaks) {
                if (cluster.start < peak.getRetTime() && peak.getRetTime() < cluster.end)
                    members.add(peak);
            }

            if (!members.isEmpty() && members.size() > MIN_CLUSTER_SIZE)
                adjustedClusters.add(new Cluster(members, cluster.start, cluster.end));
        }

        // Drop duplicates, as defined by Cluster.equals
        return adjustedClusters.stream()
                .distinct()
                .collect(Collectors.toList());
    }

    /**
     * Finds the boundary between two clusters that overlap in retention time.
     * <p>
     * The overlap is the open interval between the start of the second cluster and the end of the first cluster.
     * The boundary is the retention time of the total-intensity chromatogram with the lowest intensity inside
     * the overlap; if no point of the chromatogram lies inside it, the middle of the overlap is returned.
     *
     * @param cluster1 the cluster with the smaller average retention time
     * @param cluster2 the cluster with the greater average retention time
     * @param tic      total intensity keyed by retention time
     * @return the boundary, or {@code null} if the end of {@code cluster1} does not exceed the start of
     * {@code cluster2}
     * @throws IllegalArgumentException if the average retention time of {@code cluster1} is greater than that of
     *                                  {@code cluster2}
     */
    @Nullable
    private Double findBoundary(Cluster cluster1, Cluster cluster2, Map<Double, Double> tic)
            throws IllegalArgumentException {

        if (cluster1.retTime > cluster2.retTime)
            throw new IllegalArgumentException("First cluster does not precede to the second cluster");

        final double overlapStart = cluster2.start;
        final double overlapEnd = cluster1.end;

        // The first cluster ends before (or exactly where) the second one starts: nothing to adjust
        if (overlapEnd <= overlapStart)
            return null;

        // Look for the point with the lowest total intensity strictly inside the overlap;
        // on ties, the point encountered first is kept
        Map.Entry<Double, Double> lowest = null;
        for (Map.Entry<Double, Double> point : tic.entrySet()) {
            double retTime = point.getKey();
            if (!(overlapStart < retTime && retTime < overlapEnd))
                continue;

            if (lowest == null || lowest.getValue().compareTo(point.getValue()) > 0)
                lowest = point;
        }

        if (lowest != null)
            return lowest.getKey();

        // No data point inside the overlap: fall back to its middle
        return (overlapStart + overlapEnd) / 2;
    }

    /**
     * Adds the intensities of the given peaks to a total-intensity chromatogram, restricted to the retention times
     * strictly between {@code start} and {@code end}.
     *
     * @param peaks peaks whose chromatograms are accumulated
     * @param tic   total intensity keyed by retention time; it is updated in place
     * @param start lower bound of the retention-time window (exclusive)
     * @param end   upper bound of the retention-time window (exclusive)
     */
    private void incrementTic(List<BetterPeak> peaks, Map<Double, Double> tic, double start, double end) {
        for (BetterPeak peak : peaks) {
            Chromatogram profile = peak.chromatogram;
            for (int i = 0; i < profile.length; ++i) {
                double retTime = profile.getRetTime(i);
                double intensity = profile.getIntensity(i);

                // Skip the points outside of the open window (start, end)
                if (!(start < retTime && retTime < end))
                    continue;

                double accumulated = tic.getOrDefault(retTime, 0.0);
                tic.put(retTime, intensity + accumulated);
            }
        }
    }

    /**
     * Sums the intensities of the chromatograms of the given peaks at a single retention time.
     *
     * @param peaks   peaks whose chromatograms are queried
     * @param retTime retention time at which the intensities are read
     * @return sum of the intensities; {@code 0.0} for an empty list
     */
    private double getTotalIntensity(List<BetterPeak> peaks, double retTime) {
        // NOTE(review): the meaning of the boolean flag passed to getIntensity is not visible here - confirm
        return peaks.stream()
                .mapToDouble(peak -> peak.chromatogram.getIntensity(retTime, true))
                .sum();
    }


    /**
     * Stores cluster information: list of peaks, their average retention time, the smallest and highest retention
     * times.
     * <p>
     * Two clusters are considered equal when they have the same number of peaks and the same average retention
     * time; the peaks themselves and the boundaries are not compared.
     */
    public static class Cluster {
        /**
         * Peaks of the cluster. The list passed to the constructor is stored as is, without copying.
         */
        public final List<BetterPeak> peaks;

        /**
         * Lower boundary of the cluster. Unless given explicitly, it is the smallest first retention time of the
         * peaks. The field is mutable so that the boundary can be adjusted later.
         */
        public double start;

        /**
         * Upper boundary of the cluster. Unless given explicitly, it is the highest last retention time of the
         * peaks. The field is mutable so that the boundary can be adjusted later.
         */
        public double end;

        /**
         * Average retention time of the peaks
         */
        public final double retTime;

        /**
         * Creates an instance of {@link Cluster} whose boundaries are calculated from the peaks.
         *
         * @param peaks a non-empty list of peaks
         * @throws IllegalStateException if {@code peaks} is empty
         */
        public Cluster(List<BetterPeak> peaks) {

            this.peaks = peaks;

            this.retTime = peaks.stream()
                    .mapToDouble(BetterPeak::getRetTime)
                    .average()
                    .orElseThrow(() -> new IllegalStateException("Cannot find the average retention time of a cluster"));

            this.start = peaks.stream()
                    .mapToDouble(BetterPeak::getFirstRetTime)
                    .min()
                    .orElseThrow(() -> new IllegalStateException("Cannot find the minimum retention time of a cluster"));

            this.end = peaks.stream()
                    .mapToDouble(BetterPeak::getLastRetTime)
                    .max()
                    .orElseThrow(() -> new IllegalStateException("Cannot find the maximum retention time of a cluster"));
        }

        /**
         * Creates an instance of {@link Cluster} with explicit boundaries. The average retention time is still
         * calculated from the peaks.
         *
         * @param peaks a non-empty list of peaks
         * @param start lower boundary of the cluster
         * @param end   upper boundary of the cluster
         * @throws IllegalStateException if {@code peaks} is empty
         */
        public Cluster(List<BetterPeak> peaks, double start, double end) {
            this(peaks);
            this.start = start;
            this.end = end;
        }

        @Override
        public String toString() {
            return String.format("Cluster: size = %d, ret time = %.2f", peaks.size(), retTime);
        }

        /**
         * Compares the number of peaks and the average retention time.
         * <p>
         * The retention times are compared with {@link Double#compare(double, double)} rather than {@code ==},
         * so that the result always agrees with {@link #hashCode()}, which hashes the boxed value.
         */
        @Override
        public boolean equals(Object other) {
            if (other == this)
                return true;
            else if (other instanceof Cluster) {
                Cluster that = (Cluster) other;
                return this.peaks.size() == that.peaks.size()
                        && Double.compare(this.retTime, that.retTime) == 0;
            }
            else
                return false;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.peaks.size(), this.retTime);
        }
    }

    /**
     * Adapts a {@link Range} of retention times, together with an m/z value, to the interval type used by the
     * hierarchical interval clustering.
     */
    public static class Interval extends HierarchicalIntervalClusterer.Interval {

        /** Wrapped range; never {@code null}. */
        private final Range<Double> range;

        /** The m/z value associated with the range. */
        private final double mz;

        /**
         * Creates an interval backed by the given range.
         *
         * @param range range to wrap
         * @param mz    m/z value associated with the range
         * @throws NullPointerException if {@code range} is {@code null}
         */
        public Interval(Range<Double> range, double mz) {
            this.mz = mz;
            this.range = Objects.requireNonNull(range, "Argument Range is null");
        }

        /**
         * @return the wrapped range
         */
        public Range<Double> getRange() {
            return range;
        }

        /**
         * @return the m/z value associated with the range
         */
        public double getMz() {
            return mz;
        }

        /**
         * @return the lower endpoint of the wrapped range
         */
        @Override
        public double getStart() {
            return range.lowerEndpoint();
        }

        /**
         * @return the upper endpoint of the wrapped range
         */
        @Override
        public double getEnd() {
            return range.upperEndpoint();
        }
    }


    /**
     * Sparse matrix of pairwise distances between peaks, exposed through the {@link Matrix} interface.
     * <p>
     * Only the pairs {@code (i, j)} with {@code i < j} whose distance is below the threshold are stored. The
     * elements are sorted by value and handed out one at a time by {@link #getNext()}.
     */
    private static class DistanceMatrix implements Matrix {

        /** Number of peaks. */
        private final int dimension;

        /** Stored elements, in ascending order of distance after {@link #init()}. */
        private final List<MatrixElement> elements;

        /** Position of the element returned by the next call to {@link #getNext()}. */
        private int index = 0;

        /**
         * Calculates the distances between all pairs of peaks and keeps those below the threshold.
         *
         * @param peaks     list of peaks
         * @param threshold distances equal to or greater than this value are not stored
         * @throws IllegalArgumentException if {@code peaks} is {@code null}, or if a {@code null} peak is
         *                                  encountered while calculating a distance
         */
        private DistanceMatrix(List<BetterPeak> peaks, double threshold)
                throws IllegalArgumentException {

            if (peaks == null)
                throw new IllegalArgumentException("List of peaks is null");

            this.dimension = peaks.size();
            this.elements = new ArrayList<>();

            // Visit every unordered pair of peaks exactly once
            for (int row = 0; row < dimension; ++row) {
                BetterPeak rowPeak = peaks.get(row);
                for (int col = row + 1; col < dimension; ++col) {
                    float value = distance(rowPeak, peaks.get(col));
                    if (value < threshold)
                        elements.add(new MatrixElement(row, col, value));
                }
            }

            init();
        }

        /**
         * Calculates the distance between two peaks as the greater of
         * {@code last(peak1) - first(peak2)} and {@code last(peak2) - first(peak1)}.
         *
         * @param peak1 first peak
         * @param peak2 second peak
         * @return distance between the peaks
         * @throws IllegalArgumentException if either peak is {@code null}
         */
        private float distance(BetterPeak peak1, BetterPeak peak2)
                throws IllegalArgumentException {

            if (peak1 == null || peak2 == null)
                throw new IllegalArgumentException("Either peak1 or peak2 is null");

            float firstToSecond = (float) (peak1.getLastRetTime() - peak2.getFirstRetTime());
            float secondToFirst = (float) (peak2.getLastRetTime() - peak1.getFirstRetTime());
            return Math.max(firstToSecond, secondToFirst);
        }

        /**
         * Sorts the elements by value in ascending order and rewinds the iteration to the first element.
         */
        @Override
        public void init() {
            index = 0;
            elements.sort(Comparator.comparingDouble(MatrixElement::getValue));
        }

        /**
         * @return the next element, or {@code null} when all elements have been returned
         */
        @Override
        public MatrixElement getNext() {
            if (index >= elements.size())
                return null;

            return elements.get(index++);
        }

        /**
         * @return the number of peaks the matrix was built from
         */
        @Override
        public int getDimension() {
            return dimension;
        }

        /**
         * @return the number of stored elements
         */
        @Override
        public int getNumElements() {
            return elements.size();
        }
    }
}
