/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.network.train.prop;

import java.util.Arrays;

import org.encog.engine.data.EngineDataSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.prop.TrainFlatNetworkProp;

/**
 * Flat-network trainer that updates each weight with the resilient propagation
 * (RPROP) rule.
 *
 * <p>Each weight keeps its own step size. The step size grows by
 * {@value #POSITIVE_ETA} (capped at {@code maxStep}) while the gradient keeps its
 * sign, and shrinks by {@value #NEGATIVE_ETA} (floored at {@value #DELTA_MIN})
 * when the sign flips. Only the sign of the gradient chooses the step direction;
 * its magnitude is ignored.
 *
 * <p>On a sign change no step is taken and the stored previous gradient is
 * cleared. The next iteration therefore treats the weight as having no history,
 * which is the iRprop- handling of sign reversals.
 */
public class TrainFlatNetworkResilient extends TrainFlatNetworkProp {

    /** Factor applied to a step size while the gradient keeps its sign. */
    private static final double POSITIVE_ETA = 1.2;

    /** Factor applied to a step size after the gradient changes sign. */
    private static final double NEGATIVE_ETA = 0.5;

    /** Lower bound applied to a step size whenever it is shrunk. */
    private static final double DELTA_MIN = 1.0E-6;

    /** Default {@code zeroTolerance} used by the two-argument constructor. */
    private static final double DEFAULT_ZERO_TOLERANCE = 1.0E-17;

    /** Default initial step size used by the two-argument constructor. */
    private static final double DEFAULT_INITIAL_UPDATE = 0.1;

    /** Default step-size upper bound used by the two-argument constructor. */
    private static final double DEFAULT_MAX_STEP = 50.0;

    /** Current step size for each weight, indexed like the network's weight array. */
    private final double[] updateValues;

    /** Values whose magnitude is below this are treated as zero by {@link #sign(double)}. */
    private final double zeroTolerance;

    /** Upper bound applied to a step size whenever it is grown. */
    private final double maxStep;

    /**
     * Creates a resilient-propagation trainer with explicit tuning parameters.
     *
     * @param network the flat network to train; its weight count sizes the
     *        per-weight step-size array
     * @param training the training data, passed through to the superclass
     * @param zeroTolerance magnitude below which a value is treated as zero when
     *        taking signs
     * @param initialUpdate the starting step size for every weight
     * @param maxStep the largest value a step size may grow to
     */
    public TrainFlatNetworkResilient(final FlatNetwork network, final EngineDataSet training,
            final double zeroTolerance, final double initialUpdate, final double maxStep) {
        super(network, training);
        this.updateValues = new double[network.getWeights().length];
        this.zeroTolerance = zeroTolerance;
        this.maxStep = maxStep;
        Arrays.fill(this.updateValues, initialUpdate);
    }

    /**
     * Creates a resilient-propagation trainer with default tuning parameters.
     * The defaults are a zero tolerance of 1e-17, an initial step of 0.1 and a
     * maximum step of 50.
     *
     * @param flat the flat network to train
     * @param trainingSet the training data
     */
    public TrainFlatNetworkResilient(final FlatNetwork flat, final EngineDataSet trainingSet) {
        this(flat, trainingSet, DEFAULT_ZERO_TOLERANCE, DEFAULT_INITIAL_UPDATE, DEFAULT_MAX_STEP);
    }

    /**
     * Returns the sign of {@code value}, treating anything closer to zero than
     * {@code zeroTolerance} as zero.
     *
     * @param value the value to classify
     * @return {@code 0} if {@code |value| < zeroTolerance}, otherwise {@code 1}
     *         for positive values and {@code -1} for everything else
     */
    private int sign(final double value) {
        if (Math.abs(value) < this.zeroTolerance) {
            return 0;
        }
        return value > 0.0 ? 1 : -1;
    }

    /**
     * Computes the RPROP weight change for one weight.
     *
     * <p>This method updates that weight's step size. It also updates
     * {@code lastGradient[index]} so the next call can detect a sign change.
     *
     * <p>NOTE(review): the returned step has the sign of {@code gradients[index]}.
     * That is only a descent step if the superclass stores gradients already
     * negated (i.e. as the descent direction) — confirm against
     * {@code TrainFlatNetworkProp}.
     *
     * @param gradients the current gradients, indexed by weight
     * @param lastGradient the gradients remembered from the previous iteration;
     *        entry {@code index} is overwritten by this call
     * @param index the weight to update
     * @return the amount to change the weight by; {@code 0} when the gradient
     *         has just changed sign
     */
    @Override
    public double updateWeight(final double[] gradients, final double[] lastGradient,
            final int index) {
        final double gradient = gradients[index];
        // The sign of the product says whether the gradient kept (+), flipped (-),
        // or has no usable history (0, within zeroTolerance) since last time.
        final int change = sign(gradient * lastGradient[index]);

        if (change > 0) {
            // Same direction as before: grow the step, but never beyond maxStep.
            final double delta = Math.min(this.updateValues[index] * POSITIVE_ETA, this.maxStep);
            this.updateValues[index] = delta;
            lastGradient[index] = gradient;
            return sign(gradient) * delta;
        }

        if (change < 0) {
            // Direction flipped, so the last step was presumably too large: shrink
            // it and skip this update. Clearing the remembered gradient sends the
            // next call into the no-history branch below.
            this.updateValues[index] =
                    Math.max(this.updateValues[index] * NEGATIVE_ETA, DELTA_MIN);
            lastGradient[index] = 0.0;
            return 0.0;
        }

        // No usable history (first iteration, right after a sign change, or a
        // product within zeroTolerance). Step by the weight's current step size,
        // unchanged. The magnitude must come from updateValues, not from
        // lastGradient: the latter is a gradient, not a step size, and is exactly
        // 0 right after a sign change, which would stall the weight.
        lastGradient[index] = gradient;
        return sign(gradient) * this.updateValues[index];
    }

    /**
     * Returns the per-weight step sizes.
     *
     * <p>This is the trainer's live internal array, not a copy. Writes to it
     * change the step sizes used by subsequent calls to
     * {@link #updateWeight(double[], double[], int)}.
     *
     * @return the current step size for each weight
     */
    public double[] getUpdateValues() {
        return this.updateValues;
    }
}

