/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.network.train.prop;

import org.encog.engine.data.EngineDataSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.prop.TrainFlatNetworkProp;
import org.encog.engine.util.BoundNumbers;
import org.encog.engine.util.EngineArray;

/**
 * Trains a {@link FlatNetwork} with the scaled conjugate gradient (SCG) method.
 *
 * <p>Unlike the other propagation trainers, this class does not rely on the
 * per-weight {@link #updateWeight(double[], double[], int)} hook. Instead
 * {@link #iteration()} is overridden to move the whole weight vector along a
 * conjugate search direction. The step size comes from a finite-difference
 * estimate of the curvature along that direction, regularized by the scale
 * parameter {@code lambda}.
 *
 * <p>The trainer keeps its own working copy of the weights and writes it back
 * into the network after every weight change.
 */
public class TrainFlatNetworkSCG extends TrainFlatNetworkProp {
    /**
     * Length of the finite-difference probe step used to estimate curvature.
     * It is divided by |p| so that the probe moves exactly this far along p.
     */
    protected static final double FIRST_SIGMA = 1.0E-4;
    /** Initial value of the scale parameter {@code lambda}; restored on every restart. */
    protected static final double FIRST_LAMBDA = 1.0E-6;

    /** Set once {@code k} reaches the number of weights; the next iteration resets the SCG state. */
    private boolean restart = false;
    /** Previous value of {@code lambda} used when rescaling {@code delta}. */
    private double lambda2 = 0.0;
    /** Scale parameter added to the curvature estimate. */
    private double lambda = FIRST_LAMBDA;
    /** Number of steps since the last (re)start of the conjugate directions. */
    private int k;
    /** Whether the last step was accepted; curvature is only re-estimated after an accepted step. */
    private boolean success = true;
    /** Squared magnitude of the search direction, p . p. */
    private double magP = 0.0;
    /** Current search direction. */
    private final double[] p;
    /** Negated gradient, refreshed at initialization and after every accepted step. */
    private final double[] r;
    /** Working copy of the network weights. */
    private final double[] weights;
    /** Curvature estimate along p, scaled by lambda. */
    private double delta = 0.0;
    /** Error before the current step; restored if the step is rejected. */
    private double oldError = 0.0;
    /** Weights before the current step; restored if the step is rejected. */
    private final double[] oldWeights;
    /** Gradients before the curvature probe. */
    private final double[] oldGradient;
    /** True until {@link #init()} has run (done lazily on the first iteration). */
    private boolean mustInit;

    /**
     * Creates an SCG trainer.
     *
     * @param network  the network to train; its current weights are copied
     * @param training the training data
     */
    public TrainFlatNetworkSCG(final FlatNetwork network, final EngineDataSet training) {
        super(network, training);
        this.weights = EngineArray.arrayCopy(network.getWeights());
        final int numWeights = this.weights.length;
        this.oldWeights = new double[numWeights];
        this.oldGradient = new double[numWeights];
        this.p = new double[numWeights];
        this.r = new double[numWeights];
        this.mustInit = true;
    }

    /**
     * Computes the gradients via the base class and then rescales them.
     * Each entry is multiplied by {@code -2 / (gradient array length * output count)}.
     */
    @Override
    public void calculateGradients() {
        final int outCount = this.getNetwork().getOutputCount();
        super.calculateGradients();
        // NOTE(review): the negative sign presumably converts the base class's
        // gradient sign convention so that -gradients is a descent direction
        // (see how r is derived below); confirm against TrainFlatNetworkProp.
        final double factor = -2.0 / (double) this.gradients.length / (double) outCount;
        for (int i = 0; i < this.gradients.length; i++) {
            this.gradients[i] *= factor;
        }
    }

    /**
     * Lazily initializes the search state.
     * It computes the gradient at the network's current weights and sets both
     * {@code p} and {@code r} to its negation.
     */
    private void init() {
        final int numWeights = this.weights.length;
        // The gradients below are evaluated at the network's weights, so the
        // working copy must match them. It may be stale if the network's weights
        // changed (e.g. were re-randomized) after this trainer was constructed;
        // otherwise this copy is a no-op.
        EngineArray.arrayCopy(this.network.getWeights(), this.weights);
        this.calculateGradients();
        this.k = 1;
        for (int i = 0; i < numWeights; ++i) {
            this.p[i] = this.r[i] = -this.gradients[i];
        }
        this.mustInit = false;
    }

    /**
     * Performs one scaled conjugate gradient step.
     *
     * <p>After an accepted step, the curvature along {@code p} is re-estimated.
     * This uses a finite difference of the gradient over a probe step of length
     * {@link #FIRST_SIGMA}. The weights are then moved from the saved point by
     * {@code alpha = mu / delta} along {@code p}, where {@code mu = p . r}.
     *
     * <p>The step is accepted when the comparison parameter is non-negative.
     * Otherwise it is rejected and the previous weights are restored. In both
     * cases {@code lambda} is adjusted. The network's weights are updated at
     * the end of the call.
     *
     * <p>If the search direction has zero magnitude, it is replaced by {@code r}
     * (steepest descent). If that is also zero, no descent direction exists and
     * the call returns without changing any weights.
     */
    @Override
    public void iteration() {
        if (this.mustInit) {
            this.init();
        }
        final int numWeights = this.weights.length;

        if (this.restart) {
            // Reset the scale parameters and the conjugate-direction counter.
            this.lambda = FIRST_LAMBDA;
            this.lambda2 = 0.0;
            this.k = 1;
            this.success = true;
            this.restart = false;
        }

        if (this.success) {
            this.magP = EngineArray.vectorProduct(this.p, this.p);
            if (this.magP == 0.0) {
                // A zero-magnitude direction would make sigma infinite. The probe
                // below would then turn the weights into NaN (Infinity * 0) or
                // Infinity. Fall back to steepest descent, as a restart does.
                EngineArray.arrayCopy(this.r, this.p);
                this.k = 1;
                this.magP = EngineArray.vectorProduct(this.p, this.p);
                if (this.magP == 0.0) {
                    // The gradient vanishes too: nothing to step along.
                    return;
                }
            }
            final double sigma = FIRST_SIGMA / Math.sqrt(this.magP);

            // Save the current point so a rejected step can be undone.
            EngineArray.arrayCopy(this.gradients, this.oldGradient);
            EngineArray.arrayCopy(this.weights, this.oldWeights);
            this.oldError = this.getError();

            // Probe a small step along p and estimate the curvature along p
            // from the change in gradient.
            for (int i = 0; i < numWeights; ++i) {
                this.weights[i] += sigma * this.p[i];
            }
            this.pushWeightsToNetwork();
            this.calculateGradients();
            this.delta = 0.0;
            for (int i = 0; i < numWeights; ++i) {
                final double step = (this.gradients[i] - this.oldGradient[i]) / sigma;
                this.delta += this.p[i] * step;
            }
        }

        // Apply the change in lambda to the curvature estimate.
        this.delta += (this.lambda - this.lambda2) * this.magP;
        if (this.delta <= 0.0) {
            // Raise lambda so that the scaled curvature becomes positive.
            this.lambda2 = 2.0 * (this.lambda - this.delta / this.magP);
            this.delta = this.lambda * this.magP - this.delta;
            this.lambda = this.lambda2;
        }

        // Take the step alpha * p from the saved point.
        final double mu = EngineArray.vectorProduct(this.p, this.r);
        final double alpha = mu / this.delta;
        for (int i = 0; i < numWeights; ++i) {
            this.weights[i] = this.oldWeights[i] + alpha * this.p[i];
        }
        this.pushWeightsToNetwork();
        this.calculateGradients();

        // Comparison parameter: the actual error reduction relative to the
        // reduction mu^2 / (2 * delta) predicted by the local quadratic model.
        final double gdelta = 2.0 * this.delta * (this.oldError - this.getError()) / (mu * mu);

        if (gdelta >= 0.0) {
            // Accept the step: r becomes the new negated gradient.
            double rsum = 0.0;
            for (int i = 0; i < numWeights; ++i) {
                final double tmp = -this.gradients[i];
                rsum += tmp * this.r[i];
                this.r[i] = tmp;
            }
            this.lambda2 = 0.0;
            this.success = true;
            if (this.k >= numWeights) {
                // After numWeights steps, restart from steepest descent.
                this.restart = true;
                EngineArray.arrayCopy(this.r, this.p);
            } else {
                // Build the next conjugate direction.
                final double beta = (EngineArray.vectorProduct(this.r, this.r) - rsum) / mu;
                for (int i = 0; i < numWeights; ++i) {
                    this.p[i] = this.r[i] + beta * this.p[i];
                }
                this.restart = false;
            }
            if (gdelta >= 0.75) {
                // The quadratic model predicted well; reduce the scaling.
                this.lambda *= 0.25;
            }
        } else {
            // Reject the step: go back to the saved point.
            EngineArray.arrayCopy(this.oldWeights, this.weights);
            this.currentError = this.oldError;
            this.lambda2 = this.lambda;
            this.success = false;
        }
        if (gdelta < 0.25) {
            // The quadratic model predicted poorly; increase the scaling.
            this.lambda += this.delta * (1.0 - gdelta) / this.magP;
        }
        this.lambda = BoundNumbers.bound(this.lambda);
        ++this.k;
        this.pushWeightsToNetwork();
    }

    /** Copies the trainer's working weights into the network being trained. */
    private void pushWeightsToNetwork() {
        EngineArray.arrayCopy(this.weights, this.network.getWeights());
    }

    /**
     * Per-weight update hook required by the base class.
     * It always returns {@code 0.0}. This trainer's {@link #iteration()}
     * updates the weights directly and never calls it.
     *
     * @param gradients    the current gradients (ignored)
     * @param lastGradient the previous gradients (ignored)
     * @param index        the weight index (ignored)
     * @return always {@code 0.0}
     */
    @Override
    public double updateWeight(final double[] gradients, final double[] lastGradient, final int index) {
        return 0.0;
    }
}

