/*
 * Copyright (C) 2018 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */

package dulab.adap.workflow.decomposition;

import dulab.adap.datamodel.BetterPeak;
import smile.clustering.HierarchicalClustering;
import smile.clustering.linkage.CompleteLinkage;

import javax.annotation.Nonnull;
import java.util.*;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Detects bi-Gaussian peak shapes within a single retention-time cluster.
 *
 * <p>The constructor does all the work: it collects the distinct retention times of the given
 * chromatograms that fall inside the cluster range, builds the list of peaks for that range,
 * groups the peaks by apex retention time with complete-linkage hierarchical clustering, and
 * creates one {@link BiGaussian} for every group that has at least {@code minClusterSize} peaks.
 * The results are exposed through the public fields.
 *
 * @author Du-Lab Team dulab.binf@gmail.com
 */
public class BiGaussianDetector {

    /** Half width at half maximum divided by standard deviation: {@code sqrt(2 * ln 2)}. */
    private static final double HWHM_OVER_STD = Math.sqrt(2 * Math.log(2.0));

    /** Lower bound applied to the median half-widths so the derived widths stay positive. */
    private static final double MIN_HALF_WIDTH = 1e-6;

    /** Distinct retention times within the cluster range, in ascending order. */
    public final List<Double> retTimes;

    /** Peaks built for the cluster range (a {@code PeakList}). */
    public List<Peak> peaks;

    /** One bi-Gaussian per group of peaks that met the minimum group size. */
    public final List<BiGaussian> biGaussians;


    /**
     * Builds the retention-time grid, the peak list and the bi-Gaussians for one cluster.
     *
     * @param chromatograms     chromatograms whose retention times are scanned
     * @param cluster           retention-time cluster supplying the range and the peaks
     * @param tolerance         height at which the cluster tree of apex retention times is cut;
     *                          if the whole tree is lower than this, all peaks form one group
     * @param adjustApexRetTime passed through to {@code PeakList} unchanged
     * @param minClusterSize    groups with fewer peaks than this produce no bi-Gaussian
     * @throws IllegalArgumentException if no chromatogram has a retention time inside
     *                                  {@code [cluster.start, cluster.end]}
     * @throws IllegalStateException    if the peak list built for the range is empty
     */
    BiGaussianDetector(@Nonnull List<BetterPeak> chromatograms,
                       @Nonnull RetTimeClusterer.Cluster cluster,
                       double tolerance, boolean adjustApexRetTime, int minClusterSize)
            throws IllegalStateException, IllegalArgumentException {

        final double startRetTime = cluster.start;
        final double endRetTime = cluster.end;

        retTimes = chromatograms.stream()
                .flatMapToDouble(c -> Arrays.stream(c.chromatogram.xs)
                        .filter(x -> startRetTime <= x && x <= endRetTime))
                .distinct()
                .sorted()
                .boxed()
                .collect(Collectors.toList());

        if (retTimes.isEmpty()) {
            throw new IllegalArgumentException("No retention times in the cluster range");
        }

        double[] retTimeGrid = retTimes.stream()
                .mapToDouble(Double::doubleValue)
                .toArray();

        peaks = new PeakList(chromatograms, cluster.peaks, retTimeGrid, adjustApexRetTime);

        if (peaks.isEmpty()) {
            throw new IllegalStateException(
                    String.format("No peaks found in the range [%.3f, %.3f]", startRetTime, endRetTime));
        }

        final int[] clusterIndices = assignClusters(peaks, tolerance);
        final int[] labels = Arrays.stream(clusterIndices).distinct().toArray();

        // Construct a bi-Gaussian for each sufficiently large group of peaks
        biGaussians = new ArrayList<>(labels.length);
        for (int label : labels) {

            List<Peak> members = IntStream.range(0, peaks.size())
                    .filter(i -> clusterIndices[i] == label)
                    .mapToObj(peaks::get)
                    .collect(Collectors.toList());

            if (members.size() < minClusterSize) {
                continue;
            }

            biGaussians.add(fitBiGaussian(members));
        }
    }

    /**
     * Assigns a group label to every peak based on its apex retention time.
     *
     * <p>A single peak, or a set of peaks whose final merge height is below {@code tolerance},
     * ends up in one group labelled {@code 0}. Otherwise the cluster tree is cut at
     * {@code tolerance}.
     */
    private static int[] assignClusters(List<Peak> peaks, double tolerance) {
        final int count = peaks.size();

        // Nothing to cluster: skip building the model altogether.
        if (count <= 1) {
            return new int[count];
        }

        HierarchicalClustering model =
                new HierarchicalClustering(new CompleteLinkage(apexDistances(peaks)));

        double[] heights = model.getHeight();
        // NOTE(review): partition is only invoked when the tree is at least as tall as the cut;
        // presumably it rejects a cut above the tree height - confirm against the smile docs.
        if (heights[heights.length - 1] < tolerance) {
            return new int[count];
        }

        return model.partition(tolerance);
    }

    /** Builds the symmetric matrix of absolute apex retention-time differences between peaks. */
    private static double[][] apexDistances(List<Peak> peaks) {
        final int count = peaks.size();
        double[][] proximity = new double[count][count];
        for (int i = 0; i < count; ++i) {
            for (int j = i + 1; j < count; ++j) {
                double distance = Math.abs(peaks.get(i).apexRetTime - peaks.get(j).apexRetTime);
                proximity[i][j] = distance;
                proximity[j][i] = distance;
            }
        }
        return proximity;
    }

    /**
     * Creates a bi-Gaussian from the median apex retention time and the median left and right
     * half-widths at half maximum of the given peaks; each half-width is floored at
     * {@link #MIN_HALF_WIDTH} and divided by {@link #HWHM_OVER_STD}.
     */
    private static BiGaussian fitBiGaussian(List<Peak> members) {
        double retTime = median(members, p -> p.apexRetTime);
        double leftHWHM = Math.max(
                median(members, p -> p.apexRetTime - p.halfMaxStartRetTime), MIN_HALF_WIDTH);
        double rightHWHM = Math.max(
                median(members, p -> p.halfMaxEndRetTime - p.apexRetTime), MIN_HALF_WIDTH);

        // NOTE(review): arguments are presumably (height, apex, left std, right std) - confirm
        // against the BiGaussian constructor.
        return new BiGaussian(1.0, retTime, leftHWHM / HWHM_OVER_STD, rightHWHM / HWHM_OVER_STD);
    }

    /**
     * Returns the middle value of {@code function} applied to the peaks.
     *
     * <p>The values are sorted and the element at index {@code peaks.size() / 2} is returned,
     * so for an even number of peaks this is the upper of the two middle values, not their
     * average.
     *
     * @throws IllegalArgumentException if {@code peaks} is empty
     */
    private static double median(List<Peak> peaks, ToDoubleFunction<Peak> function)
            throws IllegalArgumentException {
        return peaks.stream()
                .mapToDouble(function)
                .sorted()
                .skip(peaks.size() / 2)
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("Cannot find median"));
    }
}
