/*
 * Copyright (C) 2018 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */

package dulab.adap.workflow.simplespectraldeconvolution;

import dulab.adap.common.algorithms.PeakUtils;
import dulab.adap.datamodel.*;
import org.dulab.jsparcehc.CompleteLinkage;
import org.dulab.jsparcehc.MatrixImpl;
import org.dulab.jsparcehc.SparseHierarchicalClustererV2;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Groups peaks into components by complete-linkage hierarchical clustering of their pairwise
 * similarity, as computed by {@code PeakUtils.similarity}.
 *
 * <p>Progress can be polled with {@link #getProcessedPercent()} and the computation can be
 * stopped with {@link #cancel()}. Both backing fields are {@code volatile} so these calls may be
 * made from a thread other than the one executing {@link #run(List, Parameters)}.
 */
public class SimpleSpectralDeconvolution {

    /** Fraction (0.0 to 1.0) of the current/last {@link #run} that has completed. */
    private volatile double processed = 0.0;

    /** Set by {@link #cancel()} and polled by {@link #run}; nothing in this class resets it. */
    private volatile boolean canceled = false;

    /**
     * Returns the progress of the current or last {@link #run} call.
     *
     * @return progress as a fraction between 0.0 and 1.0 (not scaled to 100, despite the name)
     */
    public double getProcessedPercent() {
        return processed;
    }

    /**
     * Requests cancellation. {@link #run} returns an empty list at its next cancellation check.
     * The flag is never cleared, so later calls to {@link #run} on this instance also return
     * empty lists.
     */
    public void cancel() {
        canceled = true;
    }

    /**
     * Clusters the given peaks and builds one component per sufficiently large cluster.
     *
     * <p>The distance between two peaks is {@code 1 - PeakUtils.similarity(p1, p2,
     * params.retTimeTolerance)}, clamped to [0, 1]. Peaks are grouped by complete-linkage
     * hierarchical clustering cut at {@code params.peakSimilarityThreshold}. Clusters with fewer
     * than {@code params.minNumPeaks} peaks are discarded.
     *
     * @param peaks  peaks to deconvolute
     * @param params deconvolution parameters
     * @return the detected components; an empty list if {@code peaks} is empty or the run was
     *         canceled
     */
    public List<BetterComponent> run(List<BetterPeak> peaks, Parameters params) {

        final int numPeaks = peaks.size();

        processed = 0.0;
        final double processedStep = 1.0 / 3;

        // Nothing to cluster: do not hand an empty matrix to the clusterer.
        if (numPeaks == 0) {
            processed = 1.0;
            return new ArrayList<>(0);
        }

        // NOTE(review): the meaning of MatrixImpl's constructor argument is defined by
        // jsparcehc (it receives the same threshold used as the clustering cut-off) -- confirm.
        MatrixImpl distanceMatrix = new MatrixImpl((float) params.peakSimilarityThreshold);

        // Build the symmetric distance matrix (distance = 1 - similarity, clamped to [0, 1]).
        for (int i = 0; i < numPeaks; ++i) {
            distanceMatrix.add(i, i, 0F);
            for (int j = i + 1; j < numPeaks; ++j) {

                if (canceled) return new ArrayList<>(0);

                float d = 1 - (float) PeakUtils.similarity(peaks.get(i), peaks.get(j), params.retTimeTolerance);
                d = Math.max(d, 0f);
                d = Math.min(d, 1f);

                distanceMatrix.add(i, j, d);
                distanceMatrix.add(j, i, d);
            }
        }

        processed += processedStep;

        // Perform hierarchical clustering; labels maps peak index -> cluster label.
        SparseHierarchicalClustererV2 clusterer = new SparseHierarchicalClustererV2(distanceMatrix, new CompleteLinkage());
        clusterer.cluster((float) params.peakSimilarityThreshold);
        Map<Integer, Integer> labels = clusterer.getLabels();

        int[] distinctLabels = labels.values()
                .stream()
                .mapToInt(Integer::intValue)
                .distinct()
                .toArray();

        // Group peaks by cluster label in a single pass. Within each cluster, peaks keep their
        // input order. Peaks without a label belong to no cluster and are skipped.
        Map<Integer, List<BetterPeak>> peaksByLabel = new HashMap<>();
        for (int i = 0; i < numPeaks; ++i) {
            Integer label = labels.get(i);
            if (label != null)
                peaksByLabel.computeIfAbsent(label, k -> new ArrayList<>()).add(peaks.get(i));
        }

        processed += processedStep;

        // Create one component per cluster, in the order of distinctLabels.
        List<BetterComponent> components = new ArrayList<>(distinctLabels.length);
        for (int label : distinctLabels) {

            if (canceled) return new ArrayList<>(0);

            List<BetterPeak> clusterPeaks = peaksByLabel.get(label);

            // A label with no peaks in range would make getComponent throw; skip it.
            if (clusterPeaks != null && !clusterPeaks.isEmpty()
                    && clusterPeaks.size() >= params.minNumPeaks)
                components.add(getComponent(clusterPeaks));
        }

        processed += processedStep;

        return components;
    }

    /**
     * Builds a component from a cluster of peaks.
     *
     * <p>The component takes its id and info from the most intense peak (the first one on ties).
     * Its spectrum is sampled at that peak's retention time. Its chromatogram is the average of
     * all peak chromatograms, rescaled so its maximum equals that peak's intensity.
     *
     * @param peaks peaks of one cluster
     * @return the component
     * @throws IllegalStateException if {@code peaks} is empty
     */
    private BetterComponent getComponent(List<BetterPeak> peaks) throws IllegalStateException {

        // Choose the highest peak
        BetterPeak highestPeak = null;
        for (BetterPeak peak : peaks)
            if (highestPeak == null || peak.getIntensity() > highestPeak.getIntensity())
                highestPeak = peak;

        if (highestPeak == null)
            throw new IllegalStateException("Cannot find the highest peak.");

        Spectrum spectrum = calculateSpectrum(peaks, highestPeak.getRetTime());
        Chromatogram chromatogram = calculateAverageChromatogram(peaks, highestPeak.getIntensity());

        return new BetterComponent(highestPeak.id, chromatogram, spectrum, highestPeak.info);
    }

    /**
     * Builds a spectrum with one entry per peak, in input order.
     *
     * <p>Each entry has the peak's m/z, the peak's chromatogram intensity interpolated at
     * {@code retTime}, and the peak's info.
     *
     * @param peaks   peaks contributing to the spectrum
     * @param retTime retention time at which chromatogram intensities are sampled
     * @return the spectrum
     */
    private Spectrum calculateSpectrum(List<BetterPeak> peaks, double retTime) {

        double[] mzValues = new double[peaks.size()];
        double[] intensities = new double[peaks.size()];
        PeakInfo[] peakInfos = new PeakInfo[peaks.size()];
        for (int i = 0; i < peaks.size(); ++i) {
            BetterPeak peak = peaks.get(i);
            mzValues[i] = peak.getMZ();
            intensities[i] = peak.chromatogram.interpolateIntensity(retTime);
            peakInfos[i] = peak.info;
        }

        return new Spectrum(mzValues, intensities, peakInfos);
    }

    /**
     * Averages the chromatograms of the given peaks.
     *
     * <p>The result is sampled at the sorted union of all distinct retention times of the input
     * chromatograms. At each time it holds the mean of every peak's interpolated intensity. The
     * result is then rescaled so its maximum equals {@code height}. An all-zero profile is
     * returned unscaled, because rescaling it would divide by zero.
     *
     * @param peaks  peaks whose chromatograms are averaged (non-empty)
     * @param height target maximum intensity of the result
     * @return the averaged chromatogram
     * @throws IllegalStateException if the input chromatograms contain no retention times
     */
    private Chromatogram calculateAverageChromatogram(List<BetterPeak> peaks, double height) {

        double[] retTimes = peaks.stream()
                .flatMapToDouble(p -> Arrays.stream(p.chromatogram.xs))
                .distinct()
                .sorted()
                .toArray();

        double[] intensities = Arrays.stream(retTimes)
                .map(t -> peaks.stream()
                        .mapToDouble(p -> p.chromatogram.interpolateIntensity(t))
                        .average()
                        .orElse(0.0))
                .toArray();

        double maxIntensity = Arrays.stream(intensities)
                .max()
                .orElseThrow(() -> new IllegalStateException("Cannot find maximum intensity."));

        // Dividing by a zero maximum would fill the chromatogram with NaN/Infinity.
        if (maxIntensity != 0.0) {
            double scaleFactor = height / maxIntensity;
            for (int i = 0; i < intensities.length; ++i)
                intensities[i] *= scaleFactor;
        }

        return new Chromatogram(retTimes, intensities);
    }
}
