/* 
 * Copyright (C) 2016 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */
package dulab.adap.workflow;

import com.google.common.collect.Range;
import dulab.adap.common.algorithms.FeatureTools;
import dulab.adap.common.algorithms.Math;
import dulab.adap.common.algorithms.SymmetricMatrix;
import dulab.adap.common.algorithms.machineleanring.ApacheDBSCANClusteringV2;
import dulab.adap.common.algorithms.machineleanring.Clustering;
import dulab.adap.common.algorithms.machineleanring.Optimization;
import dulab.adap.common.algorithms.machineleanring.OptimizationParameters;
import dulab.adap.datamodel.Component;
import dulab.adap.datamodel.Peak;
import dulab.adap.datamodel.PeakInfo;

import javax.annotation.Nonnull;
import java.util.*;
import java.util.Map.Entry;

/**
 * Two-step decomposition of a list of peaks into components.
 *
 * <p>Peaks are first clustered by retention time. Within every retention-time
 * cluster, model peaks are selected (candidates are additionally clustered by
 * chromatogram shape when there is more than one). Finally, every merged peak is
 * fitted against the model peaks overlapping it, and the fitted coefficients
 * form the spectrum of each component.
 *
 * <p>Progress of {@link #run} is exposed through {@link #getProcessedPercent()}
 * as a fraction between 0.0 and 1.0. The progress fields are plain (non-volatile)
 * doubles; no synchronization is performed by this class.
 *
 * @author aleksandrsmirnov
 */
public class TwoStepDecomposition {

    /** Message printed when no model peak can be selected for a cluster. */
    private static final String NO_MODEL_PEAK_WARNING =
            "WARNING: Decomposition: no model peaks found";

    /** Fraction of the work done so far by the current/last call of {@link #run}. */
    private double processedPercent = 0.0;

    /** Share of the total progress assigned to each of the three stages of {@link #run}. */
    private double processedStep = 0.0;

    /**
     * Returns the progress of the current (or last) call of {@link #run}.
     *
     * @return fraction of the work done; 0.0 at the start and 1.0 when
     *         {@link #run} has finished
     */
    public double getProcessedPercent() {
        return processedPercent;
    }

    /**
     * Decomposition of peaks
     *
     * Peaks are clustered based on their retention time; within each cluster,
     * model peaks are selected, and a component is created for each model peak.
     *
     * <p>Note: {@code peaks} is sorted in place by retention time (see
     * {@link #getRetTimeClusters}) and is passed to
     * {@code FeatureTools.correctPeakBoundaries}, so the list must be modifiable.
     *
     * @param params parameters used for decomposition, clustering and gradient
     *               descent methods
     * @param peaks collection of peaks
     * @return a list of decomposed components
     */
    @Nonnull
    public List<Component> run(
            @Nonnull final TwoStepDecompositionParameters params,
            @Nonnull final List<Peak> peaks)
    {
        List<Component> result = new ArrayList<>();

        // Three stages of equal weight: retention-time clustering,
        // model-peak selection, and component building.
        processedStep = 1.0 / 3;
        processedPercent = 0.0;

        List<Peak> modelPeaks = new ArrayList<>();

        // -------------------------------------------
        // Cluster peaks based on their retention times
        // -------------------------------------------

        List<List<Peak>> filteredClusters = getRetTimeClusters(peaks,
                params.minClusterDistance,
                params.minClusterSize,
                params.minClusterIntensity);

        processedPercent += processedStep;

        // -------------------------------------
        // Within each cluster, find model peaks
        // -------------------------------------

        // Only used inside the loop, so an empty cluster list is harmless here.
        final double localProcessedStep = processedStep / filteredClusters.size();

        for (List<Peak> cluster : filteredClusters) {
            selectModelPeaks(params, peaks, cluster, modelPeaks, result);

            // Always advance the progress, even if the cluster was skipped.
            processedPercent += localProcessedStep;
        }

        // ----------------------------------------
        // Build a component for every model peak
        // ----------------------------------------

        result.addAll(buildComponents(
                modelPeaks,
                mergePeaks(peaks,
                        params.edgeToHeightRatio, params.deltaToHeightRatio),
                params.optimizationParams));

        // Stages that had nothing to iterate over never add their share,
        // so mark the work as complete explicitly.
        processedPercent = 1.0;

        return result;
    }

    /**
     * Selects the model peaks of a single retention-time cluster.
     *
     * <ul>
     *   <li>No candidate passes {@link #filterPeaks}: the best peak of the whole
     *       cluster is used and its component is added directly to
     *       {@code result}.</li>
     *   <li>Exactly one candidate: it is added to {@code modelPeaks}.</li>
     *   <li>Several candidates: they are clustered by shape and the best peak of
     *       every shape cluster is added to {@code modelPeaks}.</li>
     * </ul>
     *
     * @param params decomposition parameters
     * @param peaks all peaks (used to build the spectrum in the fallback case)
     * @param cluster peaks of the retention-time cluster being processed
     * @param modelPeaks output list receiving the selected model peaks
     * @param result output list receiving directly-built components
     */
    private void selectModelPeaks(
            TwoStepDecompositionParameters params,
            List<Peak> peaks,
            List<Peak> cluster,
            List<Peak> modelPeaks,
            List<Component> result)
    {
        List<Peak> candidates = filterPeaks(cluster,
                params.useIsShared,
                params.edgeToHeightRatio,
                params.deltaToHeightRatio,
                params.minModelPeakSharpness,
                params.deprecatedMZValues);

        if (candidates.isEmpty()) {
            // If no model peaks found, select the best peak in the cluster
            // (according to params.modelPeakChoice) as a model peak
            Peak modelPeak = findModelPeak(params.modelPeakChoice, cluster);

            if (modelPeak == null) {
                System.out.println(NO_MODEL_PEAK_WARNING);
                return;
            }

            NavigableMap<Double, Double> spectrum =
                    buildSpectrumAt(modelPeak, peaks);

            result.add(new Component(null, modelPeak, spectrum, null));
        }
        else if (candidates.size() == 1) {
            // If there is only one peak left, choose it as a model peak
            modelPeaks.add(candidates.get(0));
        }
        else {
            // If there is more than one model peak candidate, cluster the
            // candidates based on their shapes and pick the best peak of
            // every shape cluster.
            List<List<Peak>> shapeClusters = getShapeClusters(
                    candidates,
                    params.shapeSimThreshold);

            for (List<Peak> shapeCluster : shapeClusters) {
                Peak modelPeak = findModelPeak(params.modelPeakChoice, shapeCluster);

                // findModelPeak returns null when no peak has a positive score;
                // a null model peak would fail later in buildComponents.
                if (modelPeak == null) {
                    System.out.println(NO_MODEL_PEAK_WARNING);
                    continue;
                }

                modelPeaks.add(modelPeak);
            }
        }
    }

    /**
     * Constructs a spectrum by combining intensities of all peaks whose
     * [leftApexIndex, rightApexIndex] interval contains the index of the model
     * peak. Intensities are interpolated at the retention time of the model peak.
     * If several peaks share an m/z value, the last one in {@code peaks} wins.
     *
     * @param modelPeak peak defining the index and the retention time
     * @param peaks all peaks
     * @return map from m/z value to interpolated intensity
     */
    private static NavigableMap<Double, Double> buildSpectrumAt(
            Peak modelPeak, List<Peak> peaks)
    {
        NavigableMap<Double, Double> spectrum = new TreeMap<>();
        int index = modelPeak.getInfo().peakIndex;
        double retTime = modelPeak.getRetTime();

        for (Peak peak : peaks) {
            if (peak.getInfo().leftApexIndex <= index
                    && index <= peak.getInfo().rightApexIndex)
            {
                spectrum.put(peak.getMZ(), Math
                        .interpolate(retTime, peak.getChromatogram()));
            }
        }

        return spectrum;
    }

    /**
     * Finds the peak with the highest score among the given peaks. The score is
     * the intensity, the m/z value, or the sharpness of the peak, depending on
     * {@code modelPeakChoice}; any unrecognized choice falls back to sharpness.
     *
     * <p>Only strictly positive scores are considered, so the result is
     * {@code null} when the list is empty or no peak has a positive score.
     * On ties the first peak with the highest score is returned.
     *
     * @param modelPeakChoice one of the
     *        {@code TwoStepDecompositionParameters.MODEL_PEAK_CHOICE_*} constants
     * @param peaks list of peaks to choose from
     * @return the best peak, or {@code null} if there is none
     */
    private Peak findModelPeak(@Nonnull String modelPeakChoice, List<Peak> peaks) {

        Peak modelPeak = null;

        double max = 0.0;
        for (Peak peak : peaks) {
            double value;
            switch (modelPeakChoice) {
                case TwoStepDecompositionParameters.MODEL_PEAK_CHOICE_INTENSITY:
                    value = peak.getIntensity();
                    break;

                case TwoStepDecompositionParameters.MODEL_PEAK_CHOICE_MZ:
                    value = peak.getMZ();
                    break;

                case TwoStepDecompositionParameters.MODEL_PEAK_CHOICE_SHARPNESS:
                default:
                    value = FeatureTools.sharpnessYang(peak.getChromatogram());
                    break;
            }

            if (value > max) {
                max = value;
                modelPeak = peak;
            }
        }

        return modelPeak;
    }

    /**
     * Build Component from each model Peak, construct spectrum.
     *
     * <p>Every peak of {@code otherPeaks} is normalized by its intensity and
     * fitted against the normalized chromatograms of the model peaks whose
     * {@code peakIndex} lies within the peak's
     * [leftApexIndex, rightApexIndex] interval. The fitted coefficient,
     * multiplied by the peak intensity, is stored in the spectrum of the
     * corresponding model peak under the peak's m/z value.
     *
     * @param modelPeaks list of model peaks
     * @param otherPeaks list of all peaks
     * @param params optimization parameters
     * @return list of components, one per model peak, in the order of
     *         {@code modelPeaks}
     */
    private List<Component> buildComponents(
            List<Peak> modelPeaks,
            List<Peak> otherPeaks,
            OptimizationParameters params)
    {
        Map<Peak, NavigableMap<Double, Double>> spectra = new HashMap<>();
        for (Peak peak : modelPeaks) {
            spectra.put(peak, new TreeMap<Double, Double>());
        }

        // -----------------------
        // Process each Other Peak
        // -----------------------

        List<NavigableMap<Double, Double>> modelChromatograms = new ArrayList<>();
        List<Peak> localModelPeaks = new ArrayList<>();

        // Only used inside the loop, so an empty peak list is harmless here.
        double localProcessedStep = processedStep / otherPeaks.size();

        for (Peak peak : otherPeaks) {
            int leftBoundary = peak.getInfo().leftApexIndex;
            int rightBoundary = peak.getInfo().rightApexIndex;

            modelChromatograms.clear();
            localModelPeaks.clear();

            // -------------------------------------
            // For each other peak, find model peaks
            // -------------------------------------

            for (Peak modelPeak : modelPeaks) {
                if (leftBoundary <= modelPeak.getInfo().peakIndex
                        && modelPeak.getInfo().peakIndex <= rightBoundary)
                {
                    modelChromatograms.add(scaledCopy(
                            modelPeak.getChromatogram(),
                            1.0 / modelPeak.getIntensity()));
                    localModelPeaks.add(modelPeak);
                }
            }

            int modelPeakCount = modelChromatograms.size();

            // Without overlapping model peaks the fit would produce no
            // coefficients that could be stored, so it is skipped.
            if (modelPeakCount > 0) {
                NavigableMap<Double, Double> otherChromatogram = scaledCopy(
                        peak.getChromatogram(), 1.0 / peak.getIntensity());

                // Initial guess: every model peak contributes equally
                double[] coeff = new double[modelPeakCount];
                Arrays.fill(coeff, 1.0);

                // ----------------------
                // Fit peak to modelPeaks
                // ----------------------

                coeff = new Optimization().decompose(otherChromatogram,
                        modelChromatograms, coeff,
                        1e-12, params.maxIterationCount).coefficients;

                for (int i = 0; i < modelPeakCount; ++i) {
                    spectra.get(localModelPeaks.get(i))
                            .put(peak.getMZ(), coeff[i] * peak.getIntensity());
                }
            }

            processedPercent += localProcessedStep;
        }

        List<Component> result = new ArrayList<>(modelPeaks.size());
        for (Peak peak : modelPeaks) {
            result.add(new Component(null, peak, spectra.get(peak), null));
        }

        return result;
    }

    /**
     * Returns a copy of the chromatogram with every value multiplied by
     * {@code scale}. The original chromatogram is not modified.
     *
     * @param chromatogram chromatogram to copy
     * @param scale factor applied to every value
     * @return scaled copy of the chromatogram
     */
    private static NavigableMap<Double, Double> scaledCopy(
            NavigableMap<Double, Double> chromatogram, double scale)
    {
        NavigableMap<Double, Double> copy = new TreeMap<>(chromatogram);

        for (Entry<Double, Double> e : copy.entrySet()) {
            e.setValue(scale * e.getValue());
        }

        return copy;
    }

    /**
     * Filter list of Peaks based on isShared, sharpness, and deprecated m/z-values
     *
     * A peak is kept only if it is not shared (checked only when
     * {@code useIsShared} is true), its sharpness is not lower than
     * {@code minSharpness}, and its m/z value is not contained in any of the
     * deprecated ranges.
     *
     * @param peaks list of Peak objects
     * @param useIsShared if true, shared peaks will be filtered out
     * @param edgeToHeightRatio is used for calculating isShared
     * @param deltaToHeightRatio is used for calculating isShared
     * @param minSharpness minimum sharpness
     * @param deprecatedMZValues deprecated m/z-values
     * @return list of filtered Peaks, in the order of {@code peaks}
     */
    @Nonnull
    public static List<Peak> filterPeaks(@Nonnull List<Peak> peaks,
            boolean useIsShared,
            double edgeToHeightRatio,
            double deltaToHeightRatio,
            double minSharpness,
            @Nonnull List<Range<Double>> deprecatedMZValues)
    {
        List<Peak> filteredPeaks = new ArrayList<>(peaks.size());

        for (Peak peak : peaks) {
            NavigableMap<Double, Double> chromatogram = peak.getChromatogram();

            if (useIsShared && FeatureTools.isShared(
                    new ArrayList<>(chromatogram.values()),
                    edgeToHeightRatio, deltaToHeightRatio))
            {
                continue;
            }

            if (FeatureTools.sharpnessYang(chromatogram) < minSharpness) {
                continue;
            }

            if (inRange(peak.getMZ(), deprecatedMZValues)) {
                continue;
            }

            filteredPeaks.add(peak);
        }

        return filteredPeaks;
    }

    /**
     * Cluster list of Peaks based on retention time
     *
     * Clusters whose highest peak intensity does not exceed
     * {@code minClusterIntensity} are dropped. The remaining clusters are
     * sorted by their mean retention time.
     *
     * <p>Note: {@code peaks} is sorted in place by retention time, so the list
     * must be modifiable and its order is changed for the caller.
     *
     * @param peaks list of Peaks
     * @param minClusterDistance minimum distance between clusters
     * @param minClusterSize minimum size of a cluster
     * @param minClusterIntensity minimum highest intensity of a cluster
     * @return list of clusters
     */
    @Nonnull
    public static List<List<Peak>> getRetTimeClusters(
            @Nonnull List<Peak> peaks,
            double minClusterDistance,
            int minClusterSize,
            double minClusterIntensity)
    {
        // Sort all peaks based on their retention times

        Collections.sort(peaks, new Comparator<Peak>() {
            @Override
            public int compare(Peak peak1, Peak peak2) {
                return Double.compare(peak1.getRetTime(), peak2.getRetTime());
            }
        });

        // -------------
        // Cluster peaks
        // -------------

        List<Double> retTimes = new ArrayList<>(peaks.size());
        for (Peak peak : peaks) {
            retTimes.add(peak.getRetTime());
        }

        List<List<Peak>> clusters = new ApacheDBSCANClusteringV2<Peak>()
                .call(retTimes, peaks, minClusterDistance, minClusterSize, false);

        // ----------------------------------------
        // Filter clusters based on their intensity
        // ----------------------------------------

        List<List<Peak>> filteredClusters = new ArrayList<>();

        for (final List<Peak> cluster : clusters) {
            double intensity = 0.0;
            for (Peak peak : cluster) {
                if (peak.getIntensity() > intensity) {
                    intensity = peak.getIntensity();
                }
            }

            if (intensity > minClusterIntensity) {
                filteredClusters.add(cluster);
            }
        }

        // --------------------------------------------
        // Sort clusters based on their retention times
        // --------------------------------------------

        Collections.sort(filteredClusters, new Comparator<List<Peak>>() {
            @Override
            public int compare(List<Peak> cluster1, List<Peak> cluster2) {
                return Double.compare(
                        meanRetTime(cluster1), meanRetTime(cluster2));
            }
        });

        return filteredClusters;
    }

    /**
     * Calculates the mean retention time of the peaks in a cluster.
     *
     * @param cluster list of peaks
     * @return mean retention time; NaN for an empty cluster
     */
    private static double meanRetTime(List<Peak> cluster) {
        double sum = 0.0;
        for (Peak peak : cluster) {
            sum += peak.getRetTime();
        }
        return sum / cluster.size();
    }

    /**
     * Cluster list of Peaks based on their chromatograms
     *
     * The distance between two peaks is the angle (in degrees) between their
     * chromatograms, calculated from the dot product of the chromatograms
     * divided by the norms of both peaks.
     *
     * @param peaks list of Peaks
     * @param threshold the maximum 'diameter' of a cluster
     * @return list of clusters
     */
    @Nonnull
    public static List<List<Peak>> getShapeClusters(
            @Nonnull List<Peak> peaks, double threshold)
    {
        List<NavigableMap<Double, Double>> chromatograms = new ArrayList<>();
        List<Double> norms = new ArrayList<>();
        List<Peak> objects = new ArrayList<>();

        for (Peak peak : peaks) {
            chromatograms.add(peak.getChromatogram());
            norms.add(peak.getNorm());
            objects.add(peak);
        }

        // Find distances between chromatograms
        int size = peaks.size();
        SymmetricMatrix distanceMatrix = new SymmetricMatrix(size);

        for (int i = 0; i < size; ++i) {
            distanceMatrix.set(i, i, Double.MAX_VALUE);
            for (int j = i + 1; j < size; ++j) {
                double z = Math.continuous_dot_product(
                        chromatograms.get(i), chromatograms.get(j))
                        / norms.get(i) / norms.get(j);

                // Rounding errors may push the cosine slightly outside of
                // [-1, 1], where acos is undefined (NaN), so clamp it.
                z = java.lang.Math.max(-1.0, java.lang.Math.min(1.0, z));

                distanceMatrix.set(i, j,
                        180.0 * java.lang.Math.acos(z) / java.lang.Math.PI);
            }
        }

        // Cluster and return results
        return new Clustering<Peak>().hcluster(distanceMatrix, objects, threshold);
    }

    /**
     * Merge Peaks that have the same Shared Boundary in PeakInfo
     *
     * Peak boundaries are first corrected by
     * {@code FeatureTools.correctPeakBoundaries}. Then all peaks having the same
     * m/z value, {@code leftPeakIndex} and {@code rightPeakIndex} are merged
     * into a single peak: chromatograms are combined (for equal retention times,
     * the value of the later peak overwrites the earlier one) and infos are
     * combined by {@code PeakInfo.merge}.
     *
     * @param peaks list of peaks
     * @param edgeToHeightRatio threshold of edge-to-height ratios
     * @param deltaToHeightRatio threshold for delta-to-height ratio
     * @return list of merged Peaks
     */
    private List<Peak> mergePeaks(final List<Peak> peaks,
            double edgeToHeightRatio, double deltaToHeightRatio)
    {
        List<Peak> result = new ArrayList<>(peaks.size());

        FeatureTools.correctPeakBoundaries(peaks,
                edgeToHeightRatio, deltaToHeightRatio);

        // -----------------------------------------------------------
        // Find all peaks with the same Shared Boundary.
        // First, create a nested map with entries in the form
        // (m/z-value, leftSharedBoundary, rightSharedBoundary, list of Peaks)
        // -----------------------------------------------------------

        Map<Double, Map<Integer, Map<Integer, List<Peak>>>> outerMap =
                new HashMap<>();

        for (Peak peak : peaks) {
            double mz = peak.getMZ();
            int leftSharedBoundary = peak.getInfo().leftPeakIndex;
            int rightSharedBoundary = peak.getInfo().rightPeakIndex;

            Map<Integer, Map<Integer, List<Peak>>> middleMap = outerMap.get(mz);
            if (middleMap == null) {
                middleMap = new HashMap<>();
                outerMap.put(mz, middleMap);
            }

            Map<Integer, List<Peak>> innerMap = middleMap.get(leftSharedBoundary);
            if (innerMap == null) {
                innerMap = new HashMap<>();
                middleMap.put(leftSharedBoundary, innerMap);
            }

            List<Peak> group = innerMap.get(rightSharedBoundary);
            if (group == null) {
                group = new ArrayList<>();
                innerMap.put(rightSharedBoundary, group);
            }

            group.add(peak);
        }

        // -----------------------------------------------
        // For each entry of the nested map, create a peak
        // -----------------------------------------------

        for (Map<Integer, Map<Integer, List<Peak>>> middleMap : outerMap.values()) {
            for (Map<Integer, List<Peak>> innerMap : middleMap.values()) {
                for (List<Peak> peakList : innerMap.values()) {
                    NavigableMap<Double, Double> chromatogram = new TreeMap<>();

                    // Every group contains at least one peak by construction
                    PeakInfo mergedInfo = new PeakInfo(peakList.get(0).getInfo());

                    for (Peak peak : peakList) {
                        chromatogram.putAll(peak.getChromatogram());
                        mergedInfo = PeakInfo.merge(mergedInfo, peak.getInfo());
                    }

                    result.add(new Peak(chromatogram, mergedInfo));
                }
            }
        }

        return result;
    }

    /**
     * Check if m/z value belongs to one of ranges
     *
     * @param mz m/z value to check
     * @param ranges list of ranges of m/z values
     * @return true if m/z value is in one of the ranges, otherwise false
     */
    private static boolean inRange(double mz, List<Range<Double>> ranges)
    {
        for (Range<Double> range : ranges) {
            if (range.contains(mz)) {
                return true;
            }
        }

        return false;
    }
}
