/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.flat.train.prop;

import org.encog.ml.data.MLDataSet;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.flat.train.prop.TrainFlatNetworkProp;

/**
 * Flat-network trainer implementing a quick propagation (QPROP) weight update.
 *
 * <p>Each weight keeps its previous change ({@link #getLastDelta()}). The next change is
 * computed from the current slope, the previous slope and that previous change. It combines:
 * <ul>
 *   <li>a linear term scaled by {@link #getEps()};</li>
 *   <li>a quadratic (secant) estimate;</li>
 *   <li>a cap of {@code learningRate} times the previous change.</li>
 * </ul>
 *
 * <p>Note that {@code learningRate} is used both as the maximum growth factor for a step and,
 * via {@link #initOthers()}, to derive the shrink factor {@code learningRate / (1 + learningRate)}.
 */
public class TrainFlatNetworkQPROP
extends TrainFlatNetworkProp {
    /** Maximum growth factor for a step; also used to derive {@link #shrink}. */
    private double learningRate;
    /** Previous weight change for each weight, indexed like the network's weight array. */
    private double[] lastDelta;
    /** Weight-decay coefficient added to the slope of every weight. */
    private double decay = 1.0E-4;
    /** Linear-term coefficient; derived from {@link #outputEpsilon} in {@link #initOthers()}. */
    private double eps;
    /** Base linear-term coefficient before division by the training record count. */
    private double outputEpsilon = 0.35;
    /** Threshold ratio deciding between the capped step and the quadratic estimate. */
    private double shrink;

    /**
     * Creates a QPROP trainer.
     *
     * @param network the flat network to train
     * @param training the training data
     * @param theLearningRate the maximum growth factor (see class documentation)
     */
    public TrainFlatNetworkQPROP(FlatNetwork network, MLDataSet training, double theLearningRate) {
        super(network, training);
        this.learningRate = theLearningRate;
        // One slot per network weight; all previous changes start at zero.
        this.lastDelta = new double[this.network.getWeights().length];
    }

    /**
     * Derives {@code eps} and {@code shrink} from the current settings.
     *
     * <p>{@code eps} is {@code outputEpsilon} divided by the number of training records, and
     * {@code shrink} is {@code learningRate / (1 + learningRate)}.
     */
    @Override
    public void initOthers() {
        // NOTE(review): an empty training set makes this a division by zero, giving an
        // infinite (or NaN) eps — confirm callers never reach here with zero records.
        this.eps = this.outputEpsilon / (double)this.getTraining().getRecordCount();
        this.shrink = this.learningRate / (1.0 + this.learningRate);
    }

    /**
     * Returns the learning rate (maximum growth factor).
     *
     * @return the learning rate
     */
    public final double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the learning rate (maximum growth factor).
     *
     * <p>This does not recompute the shrink factor; that value is only derived in
     * {@link #initOthers()} or set explicitly via {@link #setShrink(double)}.
     *
     * @param rate the new learning rate
     */
    public final void setLearningRate(double rate) {
        this.learningRate = rate;
    }

    /**
     * Computes the change for one weight and records state for the next iteration.
     *
     * @param gradients the current gradients
     * @param lastGradient the gradients from the previous iteration; entry {@code index} is
     *     overwritten with {@code gradients[index]}
     * @param index the weight being updated
     * @return the weight change (presumably added to the weight by the caller)
     */
    @Override
    public final double updateWeight(final double[] gradients, final double[] lastGradient, final int index) {
        final double w = this.network.getWeights()[index];
        // Previous change applied to this weight.
        final double d = this.lastDelta[index];
        // Current slope plus weight decay. The gradient is negated, presumably because
        // Encog accumulates gradients with the opposite sign of dE/dw — TODO confirm.
        // Read from the parameter (not the inherited field) so it matches the write below.
        final double s = -gradients[index] + this.decay * w;
        // Previous slope, same sign convention (no decay term).
        final double p = -lastGradient[index];
        double nextStep = 0.0;
        if (d < 0.0) {
            // Last step was negative.
            if (s > 0.0) {
                // Slope still positive: add the linear term, stepping against the slope.
                nextStep -= this.eps * s;
            }
            if (s >= this.shrink * p) {
                // Slope close to or beyond the previous one: take the capped step.
                nextStep += this.learningRate * d;
            } else {
                // Otherwise use the quadratic estimate.
                // NOTE(review): if p == s (non-zero) this divides by zero and yields an
                // infinite step; matches the classic Quickprop formula, left unchanged.
                nextStep += d * s / (p - s);
            }
        } else if (d > 0.0) {
            // Last step was positive.
            if (s < 0.0) {
                // Slope still negative: add the linear term, stepping against the slope.
                nextStep -= this.eps * s;
            }
            if (s <= this.shrink * p) {
                // Slope close to or beyond the previous one: take the capped step.
                nextStep += this.learningRate * d;
            } else {
                // Otherwise use the quadratic estimate (same zero-divisor caveat as above).
                nextStep += d * s / (p - s);
            }
        } else {
            // No previous step: only the linear term applies.
            nextStep -= this.eps * s;
        }
        this.lastDelta[index] = nextStep;
        // Record the gradient in the array this method reads the previous slope from.
        lastGradient[index] = gradients[index];
        return nextStep;
    }

    /**
     * Returns the weight-decay coefficient.
     *
     * @return the decay coefficient (default {@code 1.0E-4})
     */
    public double getDecay() {
        return this.decay;
    }

    /**
     * Returns the base linear-term coefficient.
     *
     * @return the output epsilon (default {@code 0.35})
     */
    public double getOutputEpsilon() {
        return this.outputEpsilon;
    }

    /**
     * Returns the shrink factor used to choose between the capped and quadratic steps.
     *
     * @return the shrink factor
     */
    public double getShrink() {
        return this.shrink;
    }

    /**
     * Sets the shrink factor directly. A later call to {@link #initOthers()} overwrites it.
     *
     * @param shrink the new shrink factor
     */
    public void setShrink(double shrink) {
        this.shrink = shrink;
    }

    /**
     * Sets the base linear-term coefficient.
     *
     * <p>{@code eps} is only recomputed from this value in {@link #initOthers()}.
     *
     * @param outputEpsilon the new output epsilon
     */
    public void setOutputEpsilon(double outputEpsilon) {
        this.outputEpsilon = outputEpsilon;
    }

    /**
     * Returns the live array of previous weight changes (not a copy).
     *
     * @return the previous weight changes
     */
    public double[] getLastDelta() {
        return this.lastDelta;
    }

    /**
     * Replaces the array of previous weight changes (stored by reference, not copied).
     *
     * @param lastDelta the new previous-change array
     */
    public void setLastDelta(double[] lastDelta) {
        this.lastDelta = lastDelta;
    }

    /**
     * Returns the derived linear-term coefficient computed by {@link #initOthers()}.
     *
     * @return the eps value
     */
    public double getEps() {
        return this.eps;
    }

    /**
     * Sets the weight-decay coefficient.
     *
     * @param decay the new decay coefficient
     */
    public void setDecay(double decay) {
        this.decay = decay;
    }
}

