/*
 * Copyright (C) 2017 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */

package dulab.adap.common.algorithms.optimization;

import dulab.adap.common.algorithms.statistics.BiGaussian;
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.fitting.AbstractCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoint;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math3.linear.DiagonalMatrix;

import javax.annotation.Nonnull;
import java.util.Collection;

/**
 * Least-squares fitter for a bi-Gaussian peak: a Gaussian-shaped curve whose
 * left and right flanks have separate standard deviations.
 *
 * <p>The fitted parameters are, in order, {@code height}, {@code mean},
 * {@code leftStd} and {@code rightStd}, and the model is
 * <pre>
 *     f(x) = height * exp(-(x - mean)^2 / (2 * s^2)),
 *     s = leftStd   if x &lt;= mean,
 *     s = rightStd  if x &gt;  mean.
 * </pre>
 * The optimization starts from the initial guess given to the constructor,
 * and each observation contributes according to its weight.
 *
 * @author Du-Lab Team <dulab.binf@gmail.com>
 */
public class BiGaussianCurveFitter extends AbstractCurveFitter
{
    /** Implementation of {@code ParametricUnivariateFunction} that
     * evaluates value and gradient of the bi-gaussian.
     */
    private static class Parametric implements ParametricUnivariateFunction
    {
        /**
         * Computes the value of the BiGaussian at {@code x}.
         *
         * <p>The left standard deviation {@code param[2]} is used for
         * {@code x <= param[1]} and the right one {@code param[3]} otherwise.
         * A zero standard deviation leads to a division by zero, so the result
         * is then non-finite or zero.
         *
         * @param x Value for which the function must be computed.
         * @param param Values of height, mean, left and right standard deviations.
         * @return the value of the function.
         * @throws DimensionMismatchException if the size of {@code param} is not 4.
         */
        @Override
        public double value(double x, double ... param)
                throws DimensionMismatchException
        {
            validateParameters(param);

            final double height = param[0];
            final double diff = x - param[1];
            final double std = diff > 0.0 ? param[3] : param[2];

            return height * Math.exp(- diff * diff / (2 * std * std));
        }

        /**
         * Computes the gradient of the BiGaussian with respect to its parameters
         * at {@code x}.
         *
         * <p>Only the standard deviation of the side containing {@code x} affects
         * the value, so the partial derivative with respect to the other side's
         * standard deviation is always zero.
         *
         * @param x Value for which the function must be computed.
         * @param param Values of height, mean, left and right standard deviations.
         * @return the gradient vector at {@code x}, ordered as {@code param}:
         *         (d/dheight, d/dmean, d/dleftStd, d/drightStd).
         * @throws DimensionMismatchException if the size of {@code param} is not 4.
         */
        @Override
        public double[] gradient(double x, double ... param)
                throws DimensionMismatchException
        {
            validateParameters(param);

            final double height = param[0];
            final double diff = x - param[1];
            final boolean isRight = diff > 0.0;
            final double std = isRight ? param[3] : param[2];
            final double std2 = std * std;

            // d/dheight is the bare exponential factor.
            final double dh = Math.exp(- diff * diff / (2 * std2));
            // d/dmean = height * exp(...) * (x - mean) / std^2
            final double dm = dh * height * diff / (std2);
            // d/dstd = height * exp(...) * (x - mean)^2 / std^3 (active side only)
            final double dsLeft = isRight ? 0.0 : dh * height * diff * diff / (std2 * std);
            final double dsRight = isRight ? dh * height * diff * diff / (std2 * std) : 0.0;

            return new double[] {dh, dm, dsLeft, dsRight};
        }

        /**
         * Validates parameters to ensure they are appropriate for the evaluation of
         * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
         * methods.
         *
         * <p>Only the number of parameters is checked. The standard deviations
         * are deliberately not required to be strictly positive, because the
         * positivity checks were disabled. NOTE(review): presumably this lets the
         * optimizer step through non-positive values without aborting the fit.
         * Confirm before re-enabling.
         *
         * @param param Values of height, mean, left and right standard deviations.
         * @throws DimensionMismatchException if the size of {@code param} is not 4.
         */
        private void validateParameters(@Nonnull double[] param)
                throws DimensionMismatchException
        {
            if (param.length != 4)
                throw new DimensionMismatchException(param.length, 4);
        }
    }

    /** Parametric function to be fitted. */
    private static final Parametric FUNCTION = new Parametric();

    /**
     * Initial guess of height, mean, left and right standard deviations.
     * This is a private copy of the array passed to the constructor.
     */
    private final double[] initialGuess;

    /** Maximum number of iterations of the optimization algorithm. */
    private final int maxIter;

    /**
     * Creates a fitter that starts the optimization from {@code initialGuess}.
     *
     * @param initialGuess Initial values of height, mean, left and right standard
     *                     deviations. The array is copied, so later changes to it
     *                     do not affect this fitter.
     * @param maxIter Maximum number of iterations of the optimization algorithm.
     */
    public BiGaussianCurveFitter(@Nonnull double[] initialGuess, int maxIter) {
        // Defensive copy: the caller must not be able to alter the start point
        // of subsequent fits by mutating its own array.
        this.initialGuess = initialGuess.clone();
        this.maxIter = maxIter;
    }

    /**
     * {@inheritDoc}
     *
     * <p>The weight of each observation is applied through a diagonal weight
     * matrix, as in Commons Math's own curve fitters.
     */
    @Override
    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations)
    {
        // Prepare least-square problem.
        final int len = observations.size();
        final double[] target = new double[len];
        final double[] weights = new double[len];

        int i = 0;
        for (WeightedObservedPoint obs : observations) {
            target[i] = obs.getY();
            weights[i] = obs.getWeight();
            ++i;
        }

        final AbstractCurveFitter.TheoreticalValuesFunction model =
                new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);

        return new LeastSquaresBuilder()
                .maxEvaluations(Integer.MAX_VALUE)
                .maxIterations(maxIter)
                .start(initialGuess)
                .target(target)
                .weight(new DiagonalMatrix(weights))
                .model(model.getModelFunction(), model.getModelFunctionJacobian())
                .build();
    }
}
