/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.flat.train.prop;

import java.util.Arrays;

import org.encog.ml.data.MLDataSet;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.flat.train.prop.RPROPType;
import org.encog.neural.flat.train.prop.TrainFlatNetworkProp;
import org.encog.neural.networks.training.TrainingError;

/**
 * Trains a {@link FlatNetwork} with resilient propagation (RPROP).
 *
 * <p>Every weight has its own step size. While the weight's gradient keeps its sign, the step
 * grows by a factor of {@code 1.2} (capped at {@code maxStep}). When the sign flips, the step
 * shrinks by a factor of {@code 0.5} (floored at {@code 1e-6}). Only the sign of the gradient,
 * not its magnitude, decides the direction of the returned weight change. Four variants are
 * selected with {@link #setRpropType(RPROPType)}:
 * <ul>
 *   <li>{@code RPROPp} (RPROP+, the default): on a sign flip, reverts the previous weight
 *       change.</li>
 *   <li>{@code iRPROPp} (iRPROP+): on a sign flip, reverts the previous weight change only if
 *       {@code currentError > lastError}; otherwise makes no change.</li>
 *   <li>{@code RPROPm} (RPROP-): never reverts; always moves by the current step.</li>
 *   <li>{@code iRPROPm} (iRPROP-): like RPROP-, but on a sign flip makes no change this
 *       iteration and forgets the gradient.</li>
 * </ul>
 *
 * <p>NOTE(review): the returned change is {@code sign(gradient) * step}, so the caller's
 * {@code gradients} presumably already point downhill; confirm against
 * {@code TrainFlatNetworkProp}.
 */
public class TrainFlatNetworkResilient
extends TrainFlatNetworkProp {
    /** Factor by which a step size grows while the gradient keeps its sign. */
    private static final double POSITIVE_ETA = 1.2;
    /** Factor by which a step size shrinks after the gradient changes sign. */
    private static final double NEGATIVE_ETA = 0.5;
    /** Smallest value a step size may shrink to. */
    private static final double DELTA_MIN = 1.0E-6;
    /** {@code zeroTolerance} used by the two-argument constructor. */
    private static final double DEFAULT_ZERO_TOLERANCE = 1.0E-17;
    /** {@code initialUpdate} used by the two-argument constructor. */
    private static final double DEFAULT_INITIAL_UPDATE = 0.1;
    /** {@code maxStep} used by the two-argument constructor. */
    private static final double DEFAULT_MAX_STEP = 50.0;

    /** Per-weight step sizes used by RPROP+ and iRPROP+. */
    private final double[] updateValues;
    /** Per-weight step sizes used by RPROP- and iRPROP-. */
    private final double[] lastDelta;
    /** Magnitudes below this are treated as zero when taking a sign. */
    private final double zeroTolerance;
    /** Upper bound applied whenever a step size grows. */
    private final double maxStep;
    /** Weight change returned for each weight on the previous call; RPROP+/iRPROP+ revert it. */
    private final double[] lastWeightChange;
    /** Variant that {@link #updateWeight} dispatches to. */
    private RPROPType rpropType = RPROPType.RPROPp;

    /**
     * Creates a trainer with explicit RPROP parameters.
     *
     * @param network       the network being trained; its weight count sizes the per-weight state
     * @param training      the training data, passed to the superclass
     * @param zeroTolerance magnitudes below this count as zero when taking signs
     * @param initialUpdate starting step size for every weight, in all four variants
     * @param maxStep       upper bound applied whenever a step size grows
     */
    public TrainFlatNetworkResilient(FlatNetwork network, MLDataSet training, double zeroTolerance, double initialUpdate, double maxStep) {
        super(network, training);
        int weightCount = network.getWeights().length;
        this.updateValues = new double[weightCount];
        this.lastDelta = new double[weightCount];
        this.lastWeightChange = new double[weightCount];
        this.zeroTolerance = zeroTolerance;
        this.maxStep = maxStep;
        Arrays.fill(this.updateValues, initialUpdate);
        // The minus variants keep their steps in lastDelta; seeding it with initialUpdate makes
        // them honour the parameter too (a zero seed would clamp their first step to DELTA_MIN).
        Arrays.fill(this.lastDelta, initialUpdate);
    }

    /**
     * Creates a trainer with default parameters: zero tolerance {@code 1e-17}, initial step
     * {@code 0.1}, and maximum step {@code 50}.
     *
     * @param flat        the network being trained
     * @param trainingSet the training data
     */
    public TrainFlatNetworkResilient(FlatNetwork flat, MLDataSet trainingSet) {
        this(flat, trainingSet, DEFAULT_ZERO_TOLERANCE, DEFAULT_INITIAL_UPDATE, DEFAULT_MAX_STEP);
    }

    /**
     * Returns -1, 0 or 1 for the sign of {@code value}, treating magnitudes below
     * {@code zeroTolerance} as zero.
     */
    private int sign(double value) {
        if (Math.abs(value) < this.zeroTolerance) {
            return 0;
        }
        return value > 0.0 ? 1 : -1;
    }

    /**
     * Computes the change for one weight using the configured RPROP variant, and records it so
     * that RPROP+/iRPROP+ can revert it on a later call.
     *
     * @param gradients    current gradients, indexed by weight
     * @param lastGradient previous gradients; updated in place for {@code index}
     * @param index        weight index
     * @return the change to apply to weight {@code index}
     * @throws TrainingError if the configured type is not one of the four known variants
     */
    @Override
    public double updateWeight(double[] gradients, double[] lastGradient, int index) {
        double weightChange;
        switch (this.rpropType) {
            case RPROPp:
                weightChange = this.updateWeightPlus(gradients, lastGradient, index);
                break;
            case RPROPm:
                weightChange = this.updateWeightMinus(gradients, lastGradient, index);
                break;
            case iRPROPp:
                weightChange = this.updateiWeightPlus(gradients, lastGradient, index);
                break;
            case iRPROPm:
                weightChange = this.updateiWeightMinus(gradients, lastGradient, index);
                break;
            default:
                throw new TrainingError("Unknown RPROP type: " + this.rpropType);
        }
        this.lastWeightChange[index] = weightChange;
        return weightChange;
    }

    /**
     * RPROP+ update: on a gradient sign flip, shrinks the step and reverts the previous change.
     *
     * @return the change to apply to weight {@code index}
     */
    public double updateWeightPlus(double[] gradients, double[] lastGradient, int index) {
        return this.stepPlusVariant(gradients, lastGradient, index, false);
    }

    /**
     * RPROP- update: adapts the step from the sign of the gradient product and always moves by
     * {@code sign(gradient) * step}; never reverts a previous change.
     *
     * @return the change to apply to weight {@code index}
     */
    public double updateWeightMinus(double[] gradients, double[] lastGradient, int index) {
        return this.stepMinusVariant(gradients, lastGradient, index, false);
    }

    /**
     * iRPROP+ update: like RPROP+, but on a sign flip reverts the previous change only if the
     * error increased ({@code currentError > lastError}).
     *
     * @return the change to apply to weight {@code index}
     */
    public double updateiWeightPlus(double[] gradients, double[] lastGradient, int index) {
        return this.stepPlusVariant(gradients, lastGradient, index, true);
    }

    /**
     * iRPROP- update: like RPROP-, but on a sign flip makes no change this iteration and stores
     * a zero gradient so that the next iteration neither grows nor shrinks the step.
     *
     * @return the change to apply to weight {@code index}
     */
    public double updateiWeightMinus(double[] gradients, double[] lastGradient, int index) {
        return this.stepMinusVariant(gradients, lastGradient, index, true);
    }

    /**
     * Shared body of RPROP+ and iRPROP+, which keep their step sizes in {@code updateValues}.
     *
     * @param errorGated {@code true} for iRPROP+ (revert only when the error increased)
     */
    private double stepPlusVariant(double[] gradients, double[] lastGradient, int index, boolean errorGated) {
        double gradient = gradients[index];
        int change = this.sign(gradient * lastGradient[index]);
        if (change > 0) {
            double delta = Math.min(this.updateValues[index] * POSITIVE_ETA, this.maxStep);
            this.updateValues[index] = delta;
            lastGradient[index] = gradient;
            return this.sign(gradient) * delta;
        }
        if (change < 0) {
            this.updateValues[index] = Math.max(this.updateValues[index] * NEGATIVE_ETA, DELTA_MIN);
            // Forget this gradient so the next call takes the change == 0 path instead of
            // shrinking the step a second time.
            lastGradient[index] = 0.0;
            boolean revert = !errorGated || this.currentError > this.lastError;
            return revert ? -this.lastWeightChange[index] : 0.0;
        }
        // change == 0: first iteration, or a (near-)zero gradient; keep the current step.
        lastGradient[index] = gradient;
        return this.sign(gradient) * this.updateValues[index];
    }

    /**
     * Shared body of RPROP- and iRPROP-, which keep their step sizes in {@code lastDelta}.
     *
     * @param improved {@code true} for iRPROP- (skip the move and forget the gradient on a flip)
     */
    private double stepMinusVariant(double[] gradients, double[] lastGradient, int index, boolean improved) {
        double gradient = gradients[index];
        int change = this.sign(gradient * lastGradient[index]);
        double delta = this.lastDelta[index];
        if (change > 0) {
            delta = Math.min(delta * POSITIVE_ETA, this.maxStep);
        } else if (change < 0) {
            delta = Math.max(delta * NEGATIVE_ETA, DELTA_MIN);
            if (improved) {
                // iRPROP- treats the current gradient as zero after a sign flip: no move now,
                // and the stored zero makes the next call keep (not shrink) the step.
                this.lastDelta[index] = delta;
                lastGradient[index] = 0.0;
                return 0.0;
            }
        }
        // For change == 0 (first iteration, or a zero gradient on either side) delta is kept.
        this.lastDelta[index] = delta;
        lastGradient[index] = gradient;
        return this.sign(gradient) * delta;
    }

    /**
     * Returns the live per-weight step-size array used by RPROP+ and iRPROP+ (not a copy, so
     * writes by the caller change the trainer's state).
     *
     * <p>NOTE(review): RPROP- and iRPROP- keep their steps in a separate array that this getter
     * does not expose.
     */
    public double[] getUpdateValues() {
        return this.updateValues;
    }

    /** Returns the RPROP variant currently used by {@link #updateWeight}. */
    public RPROPType getRpropType() {
        return this.rpropType;
    }

    /**
     * Selects the RPROP variant used by {@link #updateWeight}.
     *
     * @param rpropType the variant; must not be {@code null}
     * @throws NullPointerException if {@code rpropType} is {@code null}
     */
    public void setRpropType(RPROPType rpropType) {
        if (rpropType == null) {
            throw new NullPointerException("rpropType");
        }
        this.rpropType = rpropType;
    }

    /** No additional initialization is needed; all per-weight state is set up in the constructor. */
    @Override
    public void initOthers() {
    }
}

