/* 
 * Copyright (C) 2016 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */
package dulab.adap.common.algorithms.machineleanring;

import cern.colt.function.DoubleFunction;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import com.joptimizer.functions.ConvexMultivariateRealFunction;
import com.joptimizer.functions.LinearMultivariateRealFunction;
import com.joptimizer.functions.PDQuadraticMultivariateRealFunction;
import com.joptimizer.optimizers.JOptimizer;
import com.joptimizer.optimizers.OptimizationRequest;


import dulab.adap.common.algorithms.Math;
import dulab.adap.common.types.MutableDouble;

import java.util.*;

import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.exception.TooManyIterationsException;

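/**
 * Optimization routines: decomposition of a signal into a non-negative
 * linear combination of reference signals, and alignment of two signals
 * by an iterative search for the shift between them.
 *
 * @author aleksandrsmirnov
 */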
public class Optimization
{
    /**
     * Holder with final fields describing the outcome of an optimization run.
     * All values are stored exactly as passed to the constructor.
     */
    public class OptimizationResult
    {
        /** {@code true} if the run is reported as successful. */
        public final boolean successful;

        /** Coefficients belonging to this result; the array is stored by reference, not copied. */
        public final double[] coefficients;

        /** Error value reported together with {@link #coefficients}. */
        public final double error;

        /** Human-readable status message. */
        public final String message;

        /**
         * Creates a result from the given values without validation or copying.
         *
         * @param successful   whether the run is reported as successful
         * @param coefficients coefficient array (kept by reference)
         * @param error        error value reported for the coefficients
         * @param message      status message
         */
        public OptimizationResult(
                boolean successful,
                double[] coefficients,
                double error,
                String message)
        {
            this.message = message;
            this.error = error;
            this.coefficients = coefficients;
            this.successful = successful;
        }
    }

    /**
     * Fits {@code signal} by a non-negative linear combination of
     * {@code referenceSignals}.
     * <p>
     * All signals are first interpolated (via the project's
     * {@code dulab.adap.common.algorithms.Math.interpolate}) onto the union of
     * their x-values. Then the function
     * {@code f(alpha) = 1/(2n) * || S * alpha - S0 ||^2} is minimized with
     * JOptimizer subject to {@code alpha >= 0}, where {@code S} is the n-by-m
     * matrix of interpolated reference signals and {@code S0} is the
     * interpolated signal.
     *
     * @param signal           signal to be decomposed (x-value to y-value)
     * @param referenceSignals signals whose linear combination should fit {@code signal}
     * @param initValues       initial point passed to the optimizer; also returned
     *                         (by reference) as the coefficients of an unsuccessful result
     * @param tolerance        value used for all optimizer tolerances
     *                         (general, inner step, feasibility and KKT)
     * @param maxIteration     maximum number of optimizer iterations
     * @return a successful result holding the solution and the objective value at
     *         the solution, or an unsuccessful result holding {@code initValues},
     *         {@code Double.MAX_VALUE} as the error and a message describing the failure
     */
    public OptimizationResult decompose(
            final NavigableMap <Double, Double> signal,
            final List <NavigableMap <Double, Double>> referenceSignals,
            final double[] initValues,
            final double tolerance,
            final int maxIteration)
    {
        // Find the union of x-values
        final SortedSet <Double> xValues = new TreeSet <> (signal.navigableKeySet());
        for (final NavigableMap <Double, Double> s : referenceSignals)
            xValues.addAll(s.navigableKeySet());

        final int n = xValues.size();
        final int m = referenceSignals.size();

        // Without any x-value the 1/n factors below would turn P, Q and R into NaN
        if (n == 0)
            return new OptimizationResult(false, initValues, Double.MAX_VALUE,
                    "Error during decomposition: signals contain no data points");

        // Interpolate values so that each signal would have the same number of points
        final double[] interpSignal = Math.interpolate(xValues, signal);
        final double[][] interpReferenceSignals = new double[m][];
        for (int i = 0; i < m; ++i)
            interpReferenceSignals[i] = Math.interpolate(xValues, referenceSignals.get(i));

        // --------------------------------------------------------------------
        // ----- Perform optimization using JOptimizer and COLT libraries -----
        // --------------------------------------------------------------------

        final Algebra alg = new Algebra();

        // ------------------------------------------------------------------
        // Minimize the function f(alpha) = 1/(2n) * || S * alpha - S0 || ^ 2
        //
        // where S is n-by-m matrix consisting of interpReferenceSignals
        //       S0 is vector consisting of interpSignal
        //       alpha is vector consisting of non-negative fitValues
        // ------------------------------------------------------------------

        // Vector S0
        final DoubleMatrix1D vectorS0 = new DenseDoubleMatrix1D(interpSignal);

        // Matrix S: column j holds the j-th interpolated reference signal
        final DoubleMatrix2D matrixS = new DenseDoubleMatrix2D(n, m);
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < m; ++j)
                matrixS.set(i, j, interpReferenceSignals[j][i]);
        }

        // Matrix ST = S^T
        final DoubleMatrix2D matrixST = alg.transpose(matrixS);

        // ------------------------------------------------------
        // Represent the function f in the form
        //
        // f(alpha) = 1/2 * alpha^T * P * alpha + Q^T * alpha + R
        //
        // where P = 1/n * S^T * S
        //       Q = - 1/n * S^T * S0
        //       R = 1/(2n) * S0^T * S0
        // ------------------------------------------------------

        // Matrix P = 1/n * S^T * S
        final DoubleMatrix2D matrixP = alg.mult(matrixST, matrixS).assign(new DoubleFunction() {
            @Override
            public double apply(double x) {return x / n;}
        });

        // Vector Q = - 1/n * S^T * S0
        final DoubleMatrix1D vectorQ = alg.mult(matrixST, vectorS0).assign(new DoubleFunction() {
            @Override
            public double apply(double x) {return -x / n;}
        });

        // Scalar R = 1/(2n) * S0^T * S0
        final double scalarR = alg.mult(vectorS0, vectorS0) / n / 2;

        // -----------------------------------
        // Create JOptimize Objective Function
        // -----------------------------------

        final PDQuadraticMultivariateRealFunction objective =
                new PDQuadraticMultivariateRealFunction(
                        matrixP.toArray(), vectorQ.toArray(), scalarR);

        // --------------------------------------------
        // Create JOptimize Inequalities for alpha >= 0
        // (each one is written as -alpha_i <= 0)
        // --------------------------------------------

        final ConvexMultivariateRealFunction[] inequalities =
                new ConvexMultivariateRealFunction[m];
        for (int i = 0; i < m; ++i) {
            final double[] v = new double[m];
            v[i] = -1;
            inequalities[i] = new LinearMultivariateRealFunction(v, 0);
        }

        // ----------------------------
        // Set up JOptimization request
        // ----------------------------

        final OptimizationRequest optim = new OptimizationRequest();
        optim.setF0(objective);
        optim.setFi(inequalities);
        optim.setInitialPoint(initValues);
        optim.setTolerance(tolerance);
        optim.setToleranceInnerStep(tolerance);
        optim.setToleranceFeas(tolerance);
        optim.setToleranceKKT(tolerance);
        optim.setMaxIteration(maxIteration);

        // --------------------
        // Perform Optimization
        // --------------------

        final JOptimizer opt = new JOptimizer();
        opt.setOptimizationRequest(optim);

        try {
            opt.optimize();
        } catch (Exception e) {
            return new OptimizationResult(false, initValues, Double.MAX_VALUE,
                    "Error during decomposition: " + e.getMessage());
        }

        // Any non-zero return code is treated as a failure; report the code to the caller
        final int returnCode = opt.getOptimizationResponse().getReturnCode();
        if (returnCode != 0)
            return new OptimizationResult(false, initValues, Double.MAX_VALUE,
                    "Decomposition returned non-zero code: " + returnCode);

        final double[] solution = opt.getOptimizationResponse().getSolution();
        return new OptimizationResult(true,
                solution,
                objective.value(solution),
                "Successful decomposition");
    }
    
//    /**
//     *
//     * Finds linear combination of referenceSignals to fit signal
//     *
//     * @param signal
//     * @param referenceSignals
//     * @param initValues
//     * @param params
//     * @return
//     */
//
//    static public List <Double> fitSignal (
//            final NavigableMap <Double, Double> signal,
//            final List <NavigableMap <Double, Double>> referenceSignals,
//            final List <Double> initValues,
//            final OptimizationParameters params)
//    {
//        // Find the union of x-values
//        SortedSet <Double> xValues = new TreeSet <> (signal.navigableKeySet());
//        for (final NavigableMap <Double, Double> s : referenceSignals)
//            xValues.addAll(new TreeSet <> (s.navigableKeySet()));
//
//        final int n = xValues.size();
//        final int m = referenceSignals.size();
//
//        // double[] fitValues = initValues.stream().mapToDouble(x -> x).toArray();
//        double[] fitValues = new double[initValues.size()];
//        for (int i = 0; i < initValues.size(); ++i) fitValues[i] = initValues.get(i);
//
//        double[] interpSignal = Math.interpolate(xValues, signal);
//        double[][] interpReferenceSignals = new double[m][n];
//        for (int i = 0; i < m; ++i)
//            interpReferenceSignals[i] =
//                    Math.interpolate(xValues, referenceSignals.get(i));
//
//        int count = 0;
//        double norm, prev_cost = Double.MAX_VALUE;
//        double alpha = params.alpha;
//
//        if (params.verbose)
//            System.out.println("----- Optimization -----");
//
//        do {
//            double[] gradient = new double[m];
//            double cost = 0.0;
//
//            for (int i = 0; i < n; ++i) {
//                double zi = interpSignal[i];
//
//                for (int j = 0; j < m; ++j)
//                    zi -= fitValues[j] * interpReferenceSignals[j][i];
//
//                cost += zi * zi;
//
//                for (int j = 0; j < m; ++j)
//                    gradient[j] += zi * interpReferenceSignals[j][i];
//            }
//
//            cost = 0.5 * cost / n;
//
//            if (java.lang.Math.abs(cost - prev_cost) < params.costTolerance)
//                break;
//
//            if (cost > prev_cost) alpha /= 2;
//
//            // Update fitValues and calculate the norm of the gradient
//
//            norm = 0.0;
//            for (int j = 0; j < m; ++j) {
//                double dj = gradient[j] / n;
//                fitValues[j] += alpha * dj;
//
//                if (fitValues[j] < 0.0) fitValues[j] = 0.0;
//
//                norm += dj * dj;
//            }
//            norm = java.lang.Math.sqrt(norm) * alpha;
//
//            if (params.verbose)
//                System.out.println("N=" + count + "; Cost=" + cost +
//                        "; GradNorm=" + norm + "; Alpha=" + alpha);
//
//            if (++count > params.maxIterationCount) {
//                System.out.println("WARNING: Gradient Descent: maximum number "
//                        + "of iterations is reached with no convergence");
//
//                // Assign the initial values
//                // fitValues = initValues.stream().mapToDouble(x -> x).toArray();
//                for (int i = 0; i < initValues.size(); ++i) fitValues[i] = initValues.get(i);
//                break;
//            }
//
//            prev_cost = cost;
//
//        } while (norm > params.gradientTolerance);
//
//        if (params.verbose) System.out.println();
//
//        // if (prev_cost < 0.7) return initValues;
//
//        List <Double> resultValues = new ArrayList <> (fitValues.length);
//        for (double v : fitValues) resultValues.add(v);
//
//        return resultValues;
//
//        // return Arrays.stream(fitValues).boxed().collect(Collectors.toList());
//    }
    
    /**
     * Iteratively searches for the shift between two signals.
     * <p>
     * Starting from the value held by {@code optimalShift}, each iteration
     * subtracts {@code params.alpha * Math.convolution(f1, df2, shift)} from the
     * current shift, where {@code df2} is {@code Math.differentiate(f2)} (both
     * from the project's {@code dulab.adap.common.algorithms.Math}). Iterations
     * stop once the absolute size of the last step is not greater than
     * {@code params.gradientTolerance}.
     *
     * @param f1           first signal
     * @param f2           second signal; its derivative drives the update step
     * @param optimalShift in: initial shift; out: final shift (only written when
     *                     the iterations finish without an exception)
     * @param maxShift     largest allowed absolute value of the shift
     * @param params       provides {@code alpha}, {@code gradientTolerance},
     *                     {@code maxIterationCount} and {@code verbose}
     * @return {@code Math.convolution(f1, f2, shift)} at the final shift, or
     *         {@code 0.0} if the derivative of {@code f2} is empty
     * @throws OutOfRangeException        if the absolute shift exceeds {@code maxShift}
     * @throws TooManyIterationsException if the iteration counter exceeds
     *                                    {@code params.maxIterationCount}
     */
    static public double alignSignals(
            final NavigableMap <Double, Double> f1,
            final NavigableMap <Double, Double> f2,
            MutableDouble optimalShift,
            final double maxShift,
            final OptimizationParameters params)

            throws OutOfRangeException, TooManyIterationsException
    {
        // Nothing to iterate on when f2 has no derivative
        final NavigableMap <Double, Double> derivative = Math.differentiate(f2);
        if (derivative.isEmpty()) {
            return 0.0;
        }

        final boolean printProgress = params.verbose;
        final double stepSize = params.alpha;
        final int iterationLimit = params.maxIterationCount;
        final double tolerance = params.gradientTolerance;

        if (printProgress) {
            System.out.println("----- Optimization -----");
        }

        double currentShift = optimalShift.get();
        int iteration = 0;

        while (true) {
            final double step = stepSize * Math.convolution(f1, derivative, currentShift);
            currentShift -= step;

            if (printProgress) {
                final double cost = Math.convolution(f1, f2, currentShift);
                System.out.println("N=" + iteration + "; Cost=" + cost +
                        "; GradNorm=" + step + "; Alpha=" + stepSize);
            }

            if (java.lang.Math.abs(currentShift) > maxShift) {
                throw new OutOfRangeException(currentShift, -maxShift, maxShift);
            }

            iteration++;
            if (iteration > iterationLimit) {
                throw new TooManyIterationsException(iterationLimit);
            }

            // Written as a negated '>' so that a NaN step also ends the loop
            if (!(java.lang.Math.abs(step) > tolerance)) {
                break;
            }
        }

        optimalShift.set(currentShift);

        return Math.convolution(f1, f2, currentShift);
    }
}
