/*
 * Decompiled with CFR 0.152.
 */
package ml.shifu.guagua.mapreduce.example.nn;

import java.util.Arrays;
import ml.shifu.guagua.mapreduce.example.nn.NNUtils;

/**
 * Per-weight update rules for the example neural network trainer.
 *
 * <p>An instance holds optimizer state for a fixed number of weights: previous deltas, previous
 * gradients and RPROP step sizes. It applies one of several update rules, selected by a
 * single-letter algorithm code compared case-insensitively:
 * <ul>
 *   <li>{@code "B"}: gradient step plus momentum term (momentum is fixed at 0 here)</li>
 *   <li>{@code "Q"}: Quickprop-style step using the quadratic estimate {@code d * s / (p - s)}</li>
 *   <li>{@code "M"}: Manhattan-style step of exactly {@code learningRate} in the gradient's sign</li>
 *   <li>{@code "S"}: placeholder (presumably scaled conjugate gradient); always yields no change</li>
 *   <li>{@code "R"}: resilient propagation (RPROP) with backtracking on a sign flip</li>
 * </ul>
 * Any other code leaves the weights unchanged.
 *
 * <p>The computed change is <em>added</em> to each weight. The supplied gradients are therefore
 * assumed to already point in the descent direction. NOTE(review): confirm against the code that
 * produces them.
 *
 * <p>Instances carry mutable per-weight state and are not synchronized.
 */
public class Weight {
    /** Gradients whose magnitude is below this are treated as zero by the "M" rule. */
    private static final double ZERO_TOLERANCE = 1.0E-17;
    /** Initial RPROP step size for every weight. */
    private static final double DEFAULT_INITIAL_UPDATE = 0.1;
    /** Upper bound on an RPROP step size. */
    private static final double DEFAULT_MAX_STEP = 50.0;
    /** Factor by which an RPROP step grows while the gradient keeps its sign. */
    private static final double POSITIVE_ETA = 1.2;
    /** Factor by which an RPROP step shrinks after the gradient changes sign. */
    private static final double NEGATIVE_ETA = 0.5;
    /** Lower bound on an RPROP step size. */
    private static final double DELTA_MIN = 1.0E-6;

    private final double learningRate;
    private final String algorithm;

    // ---- Quickprop ("Q") parameters ----
    /** Weight-decay coefficient added to the Quickprop slope. */
    private final double decay = 1.0E-4;
    /** Numerator of {@link #eps}; divided by the training-set size in the constructor. */
    private final double outputEpsilon = 0.35;
    /** Coefficient of the Quickprop linear term. */
    private final double eps;
    /** Quickprop shrink factor, {@code rate / (1 + rate)}. */
    private final double shrink;

    // ---- Shared per-weight history ----
    /** Change applied to each weight on its previous update (written by "B", "Q" and "R"). */
    private final double[] lastDelta;
    /** Gradient stored for each weight on its previous update (used by "Q" and "R"). */
    private final double[] lastGradient;

    // ---- Backpropagation ("B") ----
    /** Momentum coefficient. No setter exists, so it is always 0. */
    private final double momentum = 0.0;

    // ---- RPROP ("R") ----
    /** Current per-weight RPROP step sizes. */
    private final double[] updateValues;

    /**
     * Creates the optimizer state for {@code numWeight} weights.
     *
     * @param numWeight    number of weights to track; sizes every per-weight state array
     * @param numTrainSize training-set size, used to scale the Quickprop linear term.
     *                     NOTE(review): a value of 0 makes {@code eps} infinite, which would
     *                     turn "Q" updates into infinities/NaN.
     * @param rate         learning rate; also determines the Quickprop shrink factor
     * @param algorithm    single-letter update-rule code, compared case-insensitively. It is
     *                     dereferenced on every update, so a null value fails there.
     */
    public Weight(int numWeight, double numTrainSize, double rate, String algorithm) {
        this.learningRate = rate;
        this.algorithm = algorithm;
        this.eps = this.outputEpsilon / numTrainSize;
        this.shrink = rate / (1.0 + rate);
        // New Java arrays are zero-filled, so the previous deltas and gradients start at 0.
        this.lastDelta = new double[numWeight];
        this.lastGradient = new double[numWeight];
        this.updateValues = new double[numWeight];
        Arrays.fill(this.updateValues, DEFAULT_INITIAL_UPDATE);
    }

    /**
     * Applies one update step, in place, to every weight that has a gradient.
     *
     * <p>Weights at indices at or beyond {@code gradients.length} are left untouched.
     *
     * @param weights   current weights; modified in place. Must be at least as long as
     *                  {@code gradients}.
     * @param gradients per-weight gradients. For "B", "Q" and "R" the state arrays sized at
     *                  construction must also be at least this long.
     * @return the same {@code weights} array, after the update
     */
    public double[] calculateWeights(double[] weights, double[] gradients) {
        for (int i = 0; i < gradients.length; i++) {
            weights[i] += this.updateWeight(i, weights, gradients);
        }
        return weights;
    }

    /**
     * Returns the change for weight {@code index} under the configured rule.
     * Returns 0 when the algorithm code is not recognised.
     */
    private double updateWeight(int index, double[] weights, double[] gradients) {
        if (this.algorithm.equalsIgnoreCase("B")) {
            return this.updateWeightBP(index, weights, gradients);
        }
        if (this.algorithm.equalsIgnoreCase("Q")) {
            return this.updateWeightQBP(index, weights, gradients);
        }
        if (this.algorithm.equalsIgnoreCase("M")) {
            return this.updateWeightMHP(index, weights, gradients);
        }
        if (this.algorithm.equalsIgnoreCase("S")) {
            return this.updateWeightSCG(index, weights, gradients);
        }
        if (this.algorithm.equalsIgnoreCase("R")) {
            return this.updateWeightRLP(index, weights, gradients);
        }
        // Unknown code: leave the weight unchanged.
        return 0.0;
    }

    /** "B" rule: {@code gradient * learningRate + lastDelta * momentum}, remembered as lastDelta. */
    private double updateWeightBP(int index, double[] weights, double[] gradients) {
        double delta = gradients[index] * this.learningRate + this.lastDelta[index] * this.momentum;
        this.lastDelta[index] = delta;
        return delta;
    }

    /**
     * "Q" rule: Quickprop-style step.
     *
     * <p>Combines a linear term ({@code -eps * s}) with either a step of
     * {@code learningRate * d} or the quadratic estimate {@code d * s / (p - s)}. The choice
     * depends on how the current slope compares with the shrunk previous slope.
     */
    private double updateWeightQBP(int index, double[] weights, double[] gradients) {
        double w = weights[index];
        double d = this.lastDelta[index];
        // Current slope: negated gradient plus weight decay.
        double s = -gradients[index] + this.decay * w;
        // Previous slope: negated previous gradient (no decay term).
        double p = -this.lastGradient[index];
        double nextStep = 0.0;
        if (d < 0.0) {
            // Last step was negative.
            if (s > 0.0) {
                nextStep -= this.eps * s;
            }
            if (s >= this.shrink * p) {
                nextStep += this.learningRate * d;
            } else {
                nextStep += d * s / (p - s);
            }
        } else if (d > 0.0) {
            // Last step was positive.
            if (s < 0.0) {
                nextStep -= this.eps * s;
            }
            if (s <= this.shrink * p) {
                nextStep += this.learningRate * d;
            } else {
                nextStep += d * s / (p - s);
            }
        } else {
            // Last step was zero (or NaN): use only the linear term.
            nextStep -= this.eps * s;
        }
        this.lastDelta[index] = nextStep;
        this.lastGradient[index] = gradients[index];
        return nextStep;
    }

    /** "M" rule: {@code +/-learningRate} by the gradient's sign, or 0 for a (near-)zero gradient. */
    private double updateWeightMHP(int index, double[] weights, double[] gradients) {
        double gradient = gradients[index];
        if (Math.abs(gradient) < ZERO_TOLERANCE) {
            return 0.0;
        }
        return gradient > 0.0 ? this.learningRate : -this.learningRate;
    }

    /**
     * "S" rule: placeholder (presumably scaled conjugate gradient).
     * It is not implemented, so it always returns 0.
     */
    private double updateWeightSCG(int index, double[] weights, double[] gradients) {
        return 0.0;
    }

    /**
     * "R" rule: resilient propagation.
     *
     * <p>Each weight has its own step size. The step grows by {@link #POSITIVE_ETA} (capped at
     * {@link #DEFAULT_MAX_STEP}) while the gradient keeps its sign. When the sign flips, the step
     * shrinks by {@link #NEGATIVE_ETA} (floored at {@link #DELTA_MIN}) and the previous change is
     * undone.
     */
    private double updateWeightRLP(int index, double[] weights, double[] gradients) {
        double gradient = gradients[index];
        // Sign of g(t) * g(t-1): positive if the gradient kept its sign, negative if it flipped.
        int change = NNUtils.sign(gradient * this.lastGradient[index]);
        final double weightChange;
        if (change > 0) {
            double delta = Math.min(this.updateValues[index] * POSITIVE_ETA, DEFAULT_MAX_STEP);
            this.updateValues[index] = delta;
            weightChange = NNUtils.sign(gradient) * delta;
            this.lastGradient[index] = gradient;
        } else if (change < 0) {
            // Overshot: shrink the step and back out the previous change. Clearing the stored
            // gradient presumably sends the next call down the no-sign branch below.
            this.updateValues[index] = Math.max(this.updateValues[index] * NEGATIVE_ETA, DELTA_MIN);
            weightChange = -this.lastDelta[index];
            this.lastGradient[index] = 0.0;
        } else {
            // No sign information: step by the current size without adapting it.
            weightChange = NNUtils.sign(gradient) * this.updateValues[index];
            this.lastGradient[index] = gradient;
        }
        this.lastDelta[index] = weightChange;
        return weightChange;
    }
}

