/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation.sgd;

import org.encog.EncogError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.randomize.generate.GenerateRandom;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.error.CrossEntropyErrorFunction;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.Momentum;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.networks.training.propagation.sgd.BatchDataSet;
import org.encog.neural.networks.training.propagation.sgd.update.AdamUpdate;
import org.encog.neural.networks.training.propagation.sgd.update.UpdateRule;
import org.encog.util.EngineArray;

/**
 * Mini-batch stochastic gradient descent trainer for networks backed by a {@link FlatNetwork}.
 *
 * <p>Each call to {@link #iteration()} back-propagates every pair of the current training set
 * (normally a {@link BatchDataSet}), accumulating the gradients. It then hands the summed
 * gradients to the configured {@link UpdateRule} ({@link AdamUpdate} by default) in a single
 * {@link #update()}. Optional L1/L2 penalties are added to the output-layer deltas.
 *
 * <p>Defaults:
 * <ul>
 *   <li>learning rate 0.001 and momentum 0.9;</li>
 *   <li>a {@link CrossEntropyErrorFunction};</li>
 *   <li>a batch size of 25 when the supplied training set is not already a
 *       {@link BatchDataSet}.</li>
 * </ul>
 * Pausing and resuming training are not supported.
 */
public class StochasticGradientDescent extends BasicTraining implements Momentum, LearningRate {

    /** L1/L2 coefficients at or below this value are treated as "regularization disabled". */
    private static final double REGULARIZATION_EPSILON = 1.0E-13;

    /** Learning rate reported by {@link #getLearningRate()}; not read directly by this class. */
    private double learningRate;
    /** Momentum reported by {@link #getMomentum()}; not read directly by this class. */
    private double momentum;
    /** Gradients accumulated over the current batch, one slot per network weight. */
    private final double[] gradients;
    /** Per-neuron deltas of the pair being back-propagated, indexed like the layer outputs. */
    private final double[] layerDelta;
    /** L1 regularization coefficient (0 disables it). */
    private double l1;
    /** L2 regularization coefficient (0 disables it). */
    private double l2;
    /** Rule that turns accumulated gradients into weight changes. */
    private UpdateRule updateRule = new AdamUpdate();
    /** The network being trained. */
    private FlatNetwork flat;
    /** Error function used to compute the output-layer deltas. */
    private ErrorFunction errorFunction = new CrossEntropyErrorFunction();
    /** Error accumulated over the current batch; reset after every update. */
    private ErrorCalculation errorCalculation;
    /** Random source handed to the {@link BatchDataSet} wrapper. */
    private GenerateRandom rnd;
    /** The method being trained, as returned by {@link #getMethod()}. */
    private MLMethod method;

    /**
     * Creates a trainer that uses a {@link MersenneTwisterGenerateRandom} as its random source.
     *
     * @param network  the network to train
     * @param training the training data; wrapped in a {@link BatchDataSet} if it is not one
     */
    public StochasticGradientDescent(ContainsFlat network, MLDataSet training) {
        this(network, training, new MersenneTwisterGenerateRandom());
    }

    /**
     * Creates a trainer.
     *
     * @param network   the network to train
     * @param training  the training data; wrapped in a {@link BatchDataSet} with a batch size
     *                  of 25 if it is not already one
     * @param theRandom random source passed to the {@link BatchDataSet} wrapper
     */
    public StochasticGradientDescent(ContainsFlat network, MLDataSet training, GenerateRandom theRandom) {
        super(TrainingImplementationType.Iterative);
        this.setTraining(training);
        // Must be assigned before setBatchSize(), which hands it to the BatchDataSet wrapper.
        this.rnd = theRandom;
        if (!(training instanceof BatchDataSet)) {
            this.setBatchSize(25);
        }
        this.method = network;
        this.flat = network.getFlat();
        this.layerDelta = new double[this.flat.getLayerOutput().length];
        this.gradients = new double[this.flat.getWeights().length];
        this.errorCalculation = new ErrorCalculation();
        this.learningRate = 0.001;
        this.momentum = 0.9;
    }

    /**
     * Processes one training pair without changing any weights.
     *
     * <p>The pair is forward-propagated, its error is recorded, and its gradient contribution
     * is back-propagated into the accumulated gradients. Call {@link #update()} once the whole
     * batch has been processed.
     *
     * @param pair the input/ideal pair to learn from
     */
    public void process(MLDataPair pair) {
        double[] actual = new double[this.flat.getOutputCount()];
        this.flat.compute(pair.getInputArray(), actual);
        // Accumulate (do not reset here) so the error published by update() covers the
        // whole batch rather than only its last pair.
        this.errorCalculation.updateError(actual, pair.getIdealArray(), pair.getSignificance());
        // Output-layer deltas; activation function 0 is the one applied to the output layer.
        this.errorFunction.calculateError(this.flat.getActivationFunctions()[0],
                this.flat.getLayerSums(), this.flat.getLayerOutput(),
                pair.getIdeal().getData(), actual, this.layerDelta, 0.0, pair.getSignificance());
        if (this.l1 > REGULARIZATION_EPSILON || this.l2 > REGULARIZATION_EPSILON) {
            double[] penalties = new double[2];
            this.calculateRegularizationPenalty(penalties);
            // The penalty is the same for every output neuron, so compute it once.
            double penalty = penalties[0] * this.l1 + penalties[1] * this.l2;
            for (int i = 0; i < actual.length; ++i) {
                this.layerDelta[i] += penalty;
            }
        }
        for (int level = this.flat.getBeginTraining(); level < this.flat.getEndTraining(); ++level) {
            this.processLevel(level);
        }
    }

    /**
     * Applies the accumulated batch gradients to the network weights via the update rule.
     *
     * <p>After updating, this method:
     * <ul>
     *   <li>publishes the batch error;</li>
     *   <li>clears the gradient and error accumulators;</li>
     *   <li>advances to the next batch when training on a {@link BatchDataSet}.</li>
     * </ul>
     * The update rule is initialized when the iteration counter is still 0.
     */
    public void update() {
        if (this.getIteration() == 0) {
            this.updateRule.init(this);
        }
        this.preIteration();
        this.updateRule.update(this.gradients, this.flat.getWeights());
        this.setError(this.errorCalculation.calculate());
        this.postIteration();
        EngineArray.fill(this.gradients, 0.0);
        this.errorCalculation.reset();
        if (this.getTraining() instanceof BatchDataSet) {
            ((BatchDataSet) this.getTraining()).advance();
        }
    }

    /** Discards any error accumulated since the last update. */
    public void resetError() {
        this.errorCalculation.reset();
    }

    /**
     * Back-propagates one level of the network (chain rule).
     *
     * <p>For every neuron of the from-layer ({@code currentLevel + 1}), this method:
     * <ul>
     *   <li>accumulates the gradients of its weights into the to-layer ({@code currentLevel});</li>
     *   <li>replaces its delta with the weighted sum of the to-layer deltas, multiplied by the
     *       from-layer activation derivative.</li>
     * </ul>
     *
     * @param currentLevel index of the to-layer in the flat network's layer arrays
     */
    private void processLevel(int currentLevel) {
        int[] layerIndex = this.flat.getLayerIndex();
        int fromLayerIndex = layerIndex[currentLevel + 1];
        int toLayerIndex = layerIndex[currentLevel];
        int fromLayerSize = this.flat.getLayerCounts()[currentLevel + 1];
        int toLayerSize = this.flat.getLayerFeedCounts()[currentLevel];
        int weightStart = this.flat.getWeightIndex()[currentLevel];
        // The deltas written below belong to the from-layer, so the derivative must come from
        // that layer's activation (index currentLevel + 1), not the to-layer's.
        ActivationFunction activation = this.flat.getActivationFunctions()[currentLevel + 1];
        double[] deltas = this.layerDelta;
        double[] weights = this.flat.getWeights();
        double[] grads = this.gradients;
        double[] outputs = this.flat.getLayerOutput();
        double[] sums = this.flat.getLayerSums();
        int toEnd = toLayerIndex + toLayerSize;
        for (int y = 0; y < fromLayerSize; ++y) {
            int yi = fromLayerIndex + y;
            double output = outputs[yi];
            double sum = 0.0;
            // The weights leaving from-neuron y are strided by fromLayerSize across the to-layer.
            for (int xi = toLayerIndex, wi = weightStart + y; xi < toEnd; ++xi, wi += fromLayerSize) {
                grads[wi] += output * deltas[xi];
                sum += weights[wi] * deltas[xi];
            }
            deltas[yi] = sum * activation.derivativeFunction(sums[yi], outputs[yi]);
        }
    }

    /**
     * Performs one training iteration.
     *
     * <p>Every pair of the current batch is back-propagated, followed by exactly one
     * {@link #update()}.
     */
    @Override
    public void iteration() {
        MLDataSet batch = this.getTraining();
        for (int i = 0; i < batch.size(); ++i) {
            this.process(batch.get(i));
        }
        // update() already initializes the rule, runs pre/postIteration and advances the batch;
        // repeating those steps here would count each iteration twice and skip a batch.
        this.update();
    }

    /** @return always {@code false}; this trainer cannot be paused and resumed */
    @Override
    public boolean canContinue() {
        return false;
    }

    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    @Override
    public double getMomentum() {
        return this.momentum;
    }

    /**
     * Reports whether {@code state} can be resumed.
     *
     * @param state a saved training state
     * @return always {@code false}; resuming is not supported
     */
    public boolean isValidResume(TrainingContinuation state) {
        return false;
    }

    /** @return always {@code null}; pausing is not supported */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Not supported.
     *
     * @throws EncogError always
     */
    @Override
    public void resume(TrainingContinuation state) {
        throw new EncogError("Resume not currently supported.");
    }

    @Override
    public MLMethod getMethod() {
        return this.method;
    }

    @Override
    public void setLearningRate(double rate) {
        this.learningRate = rate;
    }

    @Override
    public void setMomentum(double m) {
        this.momentum = m;
    }

    @Override
    public void preIteration() {
        super.preIteration();
    }

    /** @return the batch size of the {@link BatchDataSet} being trained on, or 0 if none */
    public int getBatchSize() {
        if (this.getTraining() instanceof BatchDataSet) {
            return ((BatchDataSet) this.getTraining()).getBatchSize();
        }
        return 0;
    }

    /**
     * Sets the mini-batch size.
     *
     * <p>If the training set is not yet a {@link BatchDataSet}, it is first wrapped in one.
     *
     * @param theBatchSize number of pairs per batch
     */
    public void setBatchSize(int theBatchSize) {
        if (this.getTraining() instanceof BatchDataSet) {
            ((BatchDataSet) this.getTraining()).setBatchSize(theBatchSize);
        } else {
            BatchDataSet batchSet = new BatchDataSet(this.getTraining(), this.rnd);
            // Apply the requested size to the new wrapper; it was previously ignored.
            batchSet.setBatchSize(theBatchSize);
            this.setTraining(batchSet);
        }
    }

    public double getL1() {
        return this.l1;
    }

    public void setL1(double l1) {
        this.l1 = l1;
    }

    public double getL2() {
        return this.l2;
    }

    public void setL2(double l2) {
        this.l2 = l2;
    }

    /**
     * Adds the network-wide regularization sums into {@code l}.
     *
     * <p>{@code l[0]} receives the sum of absolute weights (L1) and {@code l[1]} the sum of
     * squared weights (L2). The array is not cleared first.
     *
     * @param l accumulator array of length at least 2
     */
    public void calculateRegularizationPenalty(double[] l) {
        for (int i = 0; i < this.flat.getLayerCounts().length - 1; ++i) {
            this.layerRegularizationPenalty(i, l);
        }
    }

    /**
     * Adds the L1 ({@code l[0]}) and L2 ({@code l[1]}) sums of the weights connecting
     * {@code fromLayer} to {@code fromLayer + 1}.
     *
     * @param fromLayer source layer number, as understood by {@link FlatNetwork#getWeight}
     * @param l         accumulator array of length at least 2
     */
    public void layerRegularizationPenalty(int fromLayer, double[] l) {
        int fromCount = this.flat.getLayerTotalNeuronCount(fromLayer);
        int toCount = this.flat.getLayerNeuronCount(fromLayer + 1);
        for (int fromNeuron = 0; fromNeuron < fromCount; ++fromNeuron) {
            for (int toNeuron = 0; toNeuron < toCount; ++toNeuron) {
                double w = this.flat.getWeight(fromLayer, fromNeuron, toNeuron);
                l[0] += Math.abs(w);
                l[1] += w * w;
            }
        }
    }

    public FlatNetwork getFlat() {
        return this.flat;
    }

    public UpdateRule getUpdateRule() {
        return this.updateRule;
    }

    public void setUpdateRule(UpdateRule updateRule) {
        this.updateRule = updateRule;
    }
}

