package dulab.adap.common.algorithms.statistics;

import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

import javax.annotation.Nonnull;

/**
 * Calculates the softmax of a vector using the sign convention
 * {@code exp(-gamma * x)}: with a positive {@code gamma}, smaller input
 * values receive larger weights.
 *
 * @author Du-Lab Team <dulab.binf@gmail.com>
 */
public class SoftMax {

    /** Scaling parameter; each element is multiplied by {@code -gamma} before exponentiation. */
    private final double gamma;

    /**
     * Creates an instance of SoftMax with scaling parameter gamma = 1.0
     */
    public SoftMax() {
        this(1.0);
    }

    /**
     * Creates an instance of SoftMax with scaling parameter gamma.
     *
     * @param gamma scaling parameter, intended to be a positive real number.
     *              The value is not validated: zero and negative values are
     *              accepted and used as given.
     */
    public SoftMax(double gamma) {
        this.gamma = gamma;
    }

    /**
     * Evaluates the softmax of a vector:
     *   softmax = exp(-gamma * vector) / sum(exp(-gamma * vector))
     *
     * <p>The exponents are shifted by their maximum before exponentiation.
     * The shift cancels in the ratio, so the result is mathematically
     * unchanged, but it prevents {@code exp} from overflowing to infinity
     * (or every term underflowing to zero), either of which would turn the
     * whole result into NaN.
     *
     * @param vector vector of values; it is not modified
     * @return a new matrix of the same dimensions holding the softmax of the vector
     */
    @Nonnull
    public DoubleMatrix evaluate(@Nonnull DoubleMatrix vector) {
        // mul(double) allocates a new matrix, so the in-place operations below
        // never touch the caller's data.
        DoubleMatrix exponents = vector.mul(-gamma);

        // Nothing to normalise in an empty vector.
        if (exponents.length == 0) {
            return exponents;
        }

        // Shift only by a finite maximum: subtracting an infinite or NaN value
        // would itself introduce NaNs, so in that case the exponents are left
        // unshifted.
        double shift = exponents.max();
        if (!Double.isNaN(shift) && !Double.isInfinite(shift)) {
            exponents.subi(shift);
        }

        MatrixFunctions.expi(exponents);
        return exponents.divi(exponents.sum());
    }
}
