/*
 * Copyright (C) 2017 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */

package dulab.adap.workflow.decomposition;

import dulab.adap.datamodel.*;
import org.dulab.javanmf.algorithms.*;
import org.dulab.javanmf.updaterules.MUpdateRule;
import org.dulab.javanmf.updaterules.UpdateRule;
import org.ejml.data.DMatrixRMaj;

import javax.annotation.Nonnull;
import java.util.*;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * This class finds components (bi-gaussians) that best describe given peaks.
 * <p>
 * Step 1. Adjust the apex retention time of each peak by fitting a parabola to the top points of the peak
 * and calculating the center of that parabola.
 * <p>
 * Step 2. Cluster the adjusted apex retention times, using the hierarchical clustering with the ward-linkage.
 * <p>
 * Step 3. For each cluster, construct a bi-gaussian as an "average" of all peaks in a cluster.
 * <p>
 * Step 4. Adjust standard deviations of the bi-gaussians by minimizing the cost function ||X - W(s) x H||,
 * where X represents the detected peaks, W represents bi-gaussians with standard deviations s,
 * H represents the decomposition coefficients. The optimization is performed by alternating minimization
 * of two problems:
 * 4.1. Non-linear least squares problem argmin ||X - W(s) x H||^2 with respect to s,
 * 4.2. Linear least squares problem argmin ||X - W(s) x H||^2 with respect to H.
 * <p>
 * Step 5. Decompose every chromatogram into a linear combination of constructed components (bi-gaussians).
 *
 * @author Du-Lab Team <dulab.binf@gmail.com>
 */
public class ComponentSelector {

    private static final Logger LOG = Logger.getLogger(ComponentSelector.class.getName());

    /**
     * Factor applied to the largest entry of the peak matrix to obtain the value handed to
     * {@code MyMatrix.appendRow} for the extra (baseline) row of the initial coefficient matrix.
     */
    private static final double INITIAL_BASELINE_FACTOR = 1e-12;

    /**
     * Cancellation flag. It is {@code volatile} because {@link #setCancel(boolean)} is only useful when it is
     * called from a thread other than the one running {@link #execute}.
     */
    private volatile boolean cancel;

    private final AlternatingLeastSquaresMatrixFactorization factorization;

    /**
     * Creates an instance of the class for finding components that best fit given real peaks.
     */
    public ComponentSelector() {
        this.cancel = false;

        // NOTE(review): the last two arguments are presumably the convergence tolerance and the maximum
        // number of iterations; confirm against the javanmf documentation.
        factorization = new AlternatingLeastSquaresMatrixFactorization(new ComponentConstraint(), null, 1e-12, 40000);
    }

    /**
     * Requests (or withdraws a request for) cancellation of a running {@link #execute} call.
     * <p>
     * The flag is checked between the stages of {@link #execute}; a stage that is already running is not
     * interrupted. The flag is reset to {@code false} at the start of every {@link #execute} call.
     *
     * @param cancel {@code true} to request cancellation
     */
    public void setCancel(boolean cancel) {
        this.cancel = cancel;
    }

    /**
     * Finds components, their elution profiles and spectra.
     * <p>
     * <ol>
     * <li>Use retention-time clustering to determine the number of components and their approximate elution profiles</li>
     * <li>Perform the non-negative matrix factorization to construct elution profiles of the components</li>
     * <li>Perform decomposition of all chromatograms into the linear combination of components and the baseline</li>
     * <li>Create a list of {@link BetterComponent} for each component with its elution profile and spectrum</li>
     * </ol>
     *
     * @param chromatograms     chromatograms to be decomposed
     * @param cluster           retention-time cluster; its {@code start} and {@code end} bound the retention times
     *                          used for the decomposition
     * @param retTimeTolerance  forwarded to {@link BiGaussianDetector}
     * @param adjustApexRetTime forwarded to {@link BiGaussianDetector}
     * @param minClusterSize    forwarded to {@link BiGaussianDetector}
     * @return list of {@link BetterComponent} containing components with their elution profiles and spectra;
     * an empty list if no bi-gaussians were detected, if no chromatogram has a positive total intensity within
     * the cluster, or if the call was cancelled via {@link #setCancel(boolean)}
     */
    @Nonnull
    public List<BetterComponent> execute(List<BetterPeak> chromatograms,
                                         RetTimeClusterer.Cluster cluster,
                                         double retTimeTolerance,
                                         boolean adjustApexRetTime,
                                         int minClusterSize) {

        this.cancel = false;

        // BiGaussianDetector performs adjustment of the apex retention times, clusters those apex retention times
        // and constructs a bi-gaussian for each cluster
        BiGaussianDetector detector = new BiGaussianDetector(chromatograms, cluster,
                retTimeTolerance, adjustApexRetTime, minClusterSize);

        if (isCancelled("the non-negative matrix factorization")) {
            return new ArrayList<>();
        }

        // Retention times provided by the detector
        double[] retTimes = detector.retTimes.stream()
                .mapToDouble(Double::doubleValue)
                .toArray();

        OptimizationResult nmfResult = performNonNegativeMatrixFactorization(
                detector.peaks, detector.biGaussians, retTimes);

        if (nmfResult == null) {
            return new ArrayList<>();
        }

        if (isCancelled("the decomposition of chromatograms")) {
            return new ArrayList<>();
        }

        // --------------------------------------------------
        // ----- Perform decomposition of chromatograms -----
        // --------------------------------------------------

        double startRetTime = cluster.start;
        double endRetTime = cluster.end;
        double[] allRetTimes = getRetTimes(chromatograms, startRetTime, endRetTime);

        // Keep only the chromatograms that have a positive total intensity within the cluster
        List<BetterPeak> filteredChromatograms = chromatograms.stream()
                .filter(c -> hasPositiveSignal(c, startRetTime, endRetTime))
                .collect(Collectors.toList());

        // Nothing to decompose: do not hand an empty list of chromatograms to the matrix construction
        if (filteredChromatograms.isEmpty()) {
            return new ArrayList<>();
        }

        OptimizationResult lsrResult = performLeastSquaresRegression(
                allRetTimes, filteredChromatograms, nmfResult);

        if (isCancelled("the construction of components")) {
            return new ArrayList<>();
        }

        return createComponents(allRetTimes, filteredChromatograms, lsrResult);
    }

    /**
     * Reads the cancellation flag and logs when a cancellation is detected.
     *
     * @param nextStage description of the stage that is skipped because of the cancellation
     * @return {@code true} if cancellation has been requested
     */
    private boolean isCancelled(String nextStage) {
        if (!cancel) {
            return false;
        }
        LOG.fine("Component selection cancelled before " + nextStage);
        return true;
    }

    /**
     * Checks whether the intensities of a chromatogram within a retention-time range sum to a positive value.
     *
     * @param peak         chromatogram to check
     * @param startRetTime lower bound of the retention-time range (inclusive)
     * @param endRetTime   upper bound of the retention-time range (inclusive)
     * @return {@code true} if the sum of the intensities within the range is greater than zero
     */
    private static boolean hasPositiveSignal(BetterPeak peak, double startRetTime, double endRetTime) {
        return IntStream.range(0, peak.chromatogram.length)
                .filter(i -> startRetTime <= peak.chromatogram.xs[i])
                .filter(i -> peak.chromatogram.xs[i] <= endRetTime)
                .mapToDouble(i -> peak.chromatogram.ys[i])
                .sum() > 0.0;
    }

    /**
     * Performs Non-negative Matrix Factorization. Uses Bi-Gaussians as initial components.
     *
     * @param peaks       list of peaks
     * @param biGaussians list of bi-gaussians
     * @param retTimes    retention times of bi-gaussians
     * @return pair (ComponentMatrix, Coefficient Matrix), or {@code null} if {@code biGaussians} is empty
     */
    private OptimizationResult performNonNegativeMatrixFactorization(List<Peak> peaks, List<BiGaussian> biGaussians, double[] retTimes) {

        if (biGaussians.isEmpty()) {
            return null;
        }

        // MyMatrix representing detected peaks
        MyMatrix peakMatrix = peaksToMatrix(peaks);

        // Scale
        peakMatrix.divideColumnsBySquareRootMax();

        // Construct the matrix representing the components (bi-gaussians) and extend it by a constant column
        MyMatrix componentMatrix = biGaussiansToMatrix(biGaussians, retTimes);
        MyMatrix extendedComponentMatrix = componentMatrix.appendColumn(1.0);

        // Initial coefficients, extended by one extra row that matches the constant column
        MyMatrix coefficientMatrix = getCoefficientMatrix(peakMatrix, componentMatrix);
        MyMatrix extendedCoefficientMatrix = coefficientMatrix.appendRow(
                peakMatrix.findMaximum() * INITIAL_BASELINE_FACTOR);

        runMatrixFactorization(peakMatrix, extendedComponentMatrix, extendedCoefficientMatrix);

        // NOTE(review): presumably MyMatrix.set copies the matching part of the extended matrices back,
        // dropping the extra column/row; confirm against MyMatrix.
        componentMatrix.set(extendedComponentMatrix);
        coefficientMatrix.set(extendedCoefficientMatrix);

        return new OptimizationResult(componentMatrix, coefficientMatrix);
    }

    /**
     * Performs Least-Square decomposition of all chromatograms. Uses the components constructed by NMF
     *
     * @param allRetTimes   retention times
     * @param chromatograms chromatograms
     * @param nmfResult     result of non-negative matrix factorization
     * @return triple (ModifiedComponentMatrix, Spectra, info of every chromatogram); the spectra matrix has one
     * more row than the component matrix has columns, because of the constant column added to the components
     */
    private OptimizationResult performLeastSquaresRegression(
            double[] allRetTimes, List<BetterPeak> chromatograms, OptimizationResult nmfResult) {

        PeakInfo[] chromatogramsInfo = chromatograms.stream()
                .map(p -> p.info)
                .toArray(PeakInfo[]::new);

        // Construct the matrix representing all chromatograms
        MyMatrix peakMatrix = peaksToMatrix(new PeakList(chromatograms, allRetTimes, false));

        // Remove components identically equal to zero
        MyMatrix componentMatrix = nmfResult.componentMatrix.removeZeroColumns();

        // Scale components
        componentMatrix.divideColumnsByMax();
        MyMatrix spectraMatrix = getCoefficientMatrix(peakMatrix, componentMatrix);

        // Add the baseline to components; the coefficients of the baseline are not limited
        MyMatrix componentsWithBaseline = componentMatrix.appendColumn(1.0);
        MyMatrix limit = spectraMatrix.appendRow(Double.MAX_VALUE);
        MyMatrix spectra = new MyMatrix(new double[spectraMatrix.getNumRows() + 1][spectraMatrix.getNumCols()]);

        runMatrixRegression(peakMatrix, componentsWithBaseline, spectra, limit);

        return new OptimizationResult(componentMatrix, spectra, chromatogramsInfo);
    }

    /**
     * Creates a list of {@link BetterComponent}
     *
     * @param allRetTimes           retention times
     * @param filteredChromatograms chromatograms
     * @param lsResult              results of the least-square decomposition of chromatograms
     * @return list of {@link BetterComponent}
     */
    private List<BetterComponent> createComponents(
            double[] allRetTimes, List<BetterPeak> filteredChromatograms, OptimizationResult lsResult) {

        MyMatrix components = lsResult.componentMatrix;
        MyMatrix coefficients = lsResult.coefficientMatrix;

        int numComponents = components.getNumCols();
        int numRetTimes = components.getNumRows();

        // Info of every chromatogram, collected by the least-squares step from the same list of chromatograms
        PeakInfo[] chromatogramsInfo = lsResult.chromatogramsInfo;

        // Get all mzValues
        double[] mzValues = filteredChromatograms.stream()
                .mapToDouble(BetterPeak::getMZ)
                .toArray();

        List<BetterComponent> componentsList = new ArrayList<>(numComponents);
        double[] retTimeBuffer = new double[numRetTimes];
        double[] intensityBuffer = new double[numRetTimes];

        for (int j = 0; j < numComponents; ++j) {

            double maxHeight = Math.max(0.0, coefficients.findRowMaximum(j));

            // Collect the strictly positive points of the j-th component in a single pass
            int count = 0;
            for (int i = 0; i < numRetTimes; ++i) {
                double intensity = components.get(i, j);
                if (intensity > 0.0) {
                    retTimeBuffer[count] = allRetTimes[i];
                    intensityBuffer[count] = intensity;
                    ++count;
                }
            }

            Chromatogram chromatogram = new Chromatogram(
                    Arrays.copyOf(retTimeBuffer, count),
                    Arrays.copyOf(intensityBuffer, count));

            // Rescale the elution profile so that its maximum equals the largest coefficient of the component
            chromatogram.scale(maxHeight / components.findColumnMaximum(j));
            double[] spectrum = coefficients.getRow(j);

            // The info of the chromatogram with the largest coefficient is attached to the component
            componentsList.add(new BetterComponent(j, chromatogram,
                    new Spectrum(mzValues, spectrum, chromatogramsInfo),
                    filteredChromatograms.get(coefficients.findIndexOfRowMaximum(j)).info));
        }

        return componentsList;
    }

    /**
     * Converts peaks into a matrix without zeroing the points outside the peak boundaries.
     *
     * @param peaks list of peaks
     * @return result of {@code peaksToMatrix(peaks, false)}
     */
    private MyMatrix peaksToMatrix(List<Peak> peaks) {
        return peaksToMatrix(peaks, false);
    }

    /**
     * Converts peaks into a matrix: the intensities of every peak form one row of an array that is transposed
     * afterwards.
     *
     * @param peaks          list of peaks
     * @param keepBoundaries if {@code true}, only the intensities between {@code startIndex} and {@code endIndex}
     *                       of a peak are copied and all other entries stay zero
     * @return transposed matrix of intensities
     */
    private MyMatrix peaksToMatrix(List<Peak> peaks, boolean keepBoundaries) {
        double[][] peakArray = new double[peaks.size()][];
        for (int i = 0; i < peaks.size(); ++i) {
            Peak peak = peaks.get(i);
            int length = peak.chromatogram.length;
            peakArray[i] = new double[length];

            int startIndex = 0;
            int endIndex = length - 1;
            if (keepBoundaries) {
                startIndex = Math.max(peak.startIndex, startIndex);
                endIndex = Math.min(peak.endIndex, endIndex);
            }
            for (int j = startIndex; j <= endIndex; ++j) {
                peakArray[i][j] = peak.chromatogram.getIntensity(j);
            }
        }
        return new MyMatrix(peakArray).transpose();
    }

    /**
     * Calculates the initial coefficient matrix
     * <p>
     * For each component, the retention time (scan) of its maximum is calculated. Then, the coefficient is set
     * to the intensity of each peak at that retention time.
     *
     * @param peakMatrix      peak matrix
     * @param componentMatrix component matrix
     * @return matrix with one row per component and one column per peak
     */
    private MyMatrix getCoefficientMatrix(MyMatrix peakMatrix, MyMatrix componentMatrix) {
        int numComponents = componentMatrix.getNumCols();
        int numPeaks = peakMatrix.getNumCols();

        MyMatrix coefficientLimits = new MyMatrix(numComponents, numPeaks);
        for (int i = 0; i < numComponents; ++i) {
            int index = componentMatrix.findIndexOfColumnMaximum(i);
            for (int j = 0; j < numPeaks; ++j) {
                coefficientLimits.set(i, j, peakMatrix.get(index, j));
            }
        }
        return coefficientLimits;
    }

    /**
     * Evaluates every bi-gaussian at the given retention times and arranges the values in a matrix.
     *
     * @param biGaussians list of bi-gaussians
     * @param retTimes    retention times, at which the bi-gaussians are evaluated
     * @return transposed matrix of the evaluated bi-gaussians
     */
    private MyMatrix biGaussiansToMatrix(List<BiGaussian> biGaussians, double[] retTimes) {
        return new MyMatrix(
                biGaussians.stream()
                        .map(b -> BiGaussian.evaluate(retTimes, b))
                        .toArray(double[][]::new))
                .transpose();
    }

    /**
     * Collects the distinct retention times of all peaks that lie within the given range.
     *
     * @param peaks        list of peaks
     * @param startRetTime lower bound of the range (inclusive)
     * @param endRetTime   upper bound of the range (inclusive)
     * @return distinct retention times in ascending order
     */
    private double[] getRetTimes(List<BetterPeak> peaks, double startRetTime, double endRetTime) {

        // The sorted set removes duplicates and orders the retention times at once
        SortedSet<Double> retTimeSet = new TreeSet<>();
        for (BetterPeak peak : peaks) {
            for (double retTime : peak.chromatogram.xs) {
                if (startRetTime <= retTime && retTime <= endRetTime) {
                    retTimeSet.add(retTime);
                }
            }
        }

        return retTimeSet.stream().mapToDouble(Double::doubleValue).toArray();
    }

    /**
     * Solves an optimization problem to find W and H such that X = W x H
     * <p>
     * The values of W and H found by the solver are written back into {@code w} and {@code h}.
     *
     * @param x matrix X
     * @param w matrix W
     * @param h matrix H
     */
    private void runMatrixFactorization(MyMatrix x, MyMatrix w, MyMatrix h) {

        DMatrixRMaj matrixX = new DMatrixRMaj(x.getData());
        DMatrixRMaj matrixW = new DMatrixRMaj(w.getData());
        DMatrixRMaj matrixH = new DMatrixRMaj(h.getData());
        factorization.solve(matrixX, matrixW, matrixH, true);

        copyInto(matrixW, w);
        copyInto(matrixH, h);
    }

    /**
     * Solves an optimization problem to find H such that X = W x H
     * <p>
     * After the solver has finished, every entry of H that exceeds the corresponding entry of {@code limit}
     * is replaced by that entry. The resulting values are written back into {@code w} and {@code h}.
     *
     * @param x     matrix X
     * @param w     matrix W
     * @param h     matrix H
     * @param limit upper limits of the entries of H
     */
    private void runMatrixRegression(MyMatrix x, MyMatrix w, MyMatrix h, MyMatrix limit) {

        DMatrixRMaj matrixX = new DMatrixRMaj(x.getData());
        DMatrixRMaj matrixW = new DMatrixRMaj(w.getData());
        DMatrixRMaj matrixH = new DMatrixRMaj(h.getData());
        DMatrixRMaj matrixLimit = new DMatrixRMaj(limit.getData());

        NonNegativeLeastSquares nonNegativeLeastSquares = new NonNegativeLeastSquares();

        nonNegativeLeastSquares.solve(matrixX, matrixW, matrixH);

        // Clamp the coefficients from above
        for (int i = 0; i < matrixH.numRows; ++i) {
            for (int j = 0; j < matrixH.numCols; ++j) {
                double upperLimit = matrixLimit.unsafe_get(i, j);
                if (matrixH.unsafe_get(i, j) > upperLimit) {
                    matrixH.unsafe_set(i, j, upperLimit);
                }
            }
        }

        copyInto(matrixW, w);
        copyInto(matrixH, h);
    }

    /**
     * Copies the entries of a solver matrix into a {@link MyMatrix} with the same dimensions.
     *
     * @param source matrix to read from
     * @param target matrix to write to; its dimensions determine the copied range
     */
    private static void copyInto(DMatrixRMaj source, MyMatrix target) {
        int numRows = target.getNumRows();
        int numCols = target.getNumCols();
        for (int i = 0; i < numRows; ++i) {
            for (int j = 0; j < numCols; ++j) {
                target.set(i, j, source.get(i, j));
            }
        }
    }

    /**
     * Immutable holder of the matrices produced by an optimization step.
     */
    private static class OptimizationResult {

        /** Matrix, whose columns represent components. */
        final MyMatrix componentMatrix;

        /** Matrix of the decomposition coefficients. */
        final MyMatrix coefficientMatrix;

        /** Info of the decomposed chromatograms; {@code null} if it was not provided. */
        final PeakInfo[] chromatogramsInfo;

        OptimizationResult(MyMatrix componentMatrix, MyMatrix coefficientMatrix, PeakInfo[] chromatogramsInfo) {
            this.componentMatrix = componentMatrix;
            this.coefficientMatrix = coefficientMatrix;
            this.chromatogramsInfo = chromatogramsInfo;
        }

        OptimizationResult(MyMatrix componentMatrix, MyMatrix coefficientMatrix) {
            this(componentMatrix, coefficientMatrix, null);
        }
    }
}
