/*
 * Decompiled with CFR 0.152.
 */
package ml.shifu.guagua.mapreduce.example.nn;

import java.util.Arrays;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;

/**
 * Computes back-propagation weight gradients for a {@link FlatNetwork} over a whole training set.
 *
 * <p>{@link #run()} clears the gradient buffer, pushes every record of the training set through the
 * network, propagates the output-layer error back through the network's training layers and adds
 * each record's contribution into {@link #getGradients()}. The value returned by
 * {@link ErrorCalculation#calculate()} for that pass is then available from {@link #getError()}.
 *
 * <p>The worker keeps direct references to the network's internal arrays (weights, layer outputs,
 * layer sums, layer/weight indices) and reads the layer outputs/sums after each
 * {@code compute(...)} call, so those arrays must belong to the network currently bound (see
 * {@link #setParams(BasicNetwork)}). Instances hold mutable per-pass state and perform no
 * synchronization.
 */
public class Gradient {
    /** Network whose gradients are computed; replaced by {@link #setParams(BasicNetwork)}. */
    private FlatNetwork network;
    private final ErrorCalculation errorCalculation = new ErrorCalculation();
    /** Scratch buffer receiving the network output for the current record. */
    private double[] actual;
    /** Per-neuron error deltas, indexed like the network's layer-output array. */
    private double[] layerDelta;
    private int[] layerCounts;
    private int[] layerFeedCounts;
    private int[] layerIndex;
    private int[] weightIndex;
    private double[] layerOutput;
    private double[] layerSums;
    /** Gradient accumulator, one slot per network weight. */
    private double[] gradients;
    private double[] weights;
    /** Reusable holder that training records are read into. */
    private MLDataPair pair;
    private final MLDataSet training;
    /** Error computed at the end of the last completed {@link #run()}. */
    private double error;
    /** Per-layer constant added to every activation derivative, indexed like the activation functions. */
    private double[] flatSpot;
    private final ErrorFunction errorFunction;

    /**
     * Creates a gradient worker bound to {@code theNetwork}.
     *
     * @param theNetwork  network to differentiate; its internal arrays are cached and written to
     * @param theTraining records iterated by {@link #run()}
     * @param flatSpot    per-layer offset added to the activation derivative, indexed like the
     *                    network's activation functions (index 0 is used for the output deltas)
     * @param ef          error function used to seed the output-layer deltas
     */
    public Gradient(FlatNetwork theNetwork, MLDataSet theTraining, double[] flatSpot, ErrorFunction ef) {
        this.training = theTraining;
        this.flatSpot = flatSpot;
        this.errorFunction = ef;
        this.bind(theNetwork);
    }

    /**
     * Points this worker at {@code theNetwork}: caches its weight and layer-structure arrays and
     * sizes the scratch buffers to match it.
     *
     * <p>The delta and gradient buffers keep their array identity when the required length is
     * unchanged, so references previously returned by {@link #getLayerDelta()} and
     * {@link #getGradients()} remain live across a rebind to a same-shaped network.
     */
    private void bind(FlatNetwork theNetwork) {
        this.network = theNetwork;
        this.weights = theNetwork.getWeights();
        this.layerIndex = theNetwork.getLayerIndex();
        this.layerCounts = theNetwork.getLayerCounts();
        this.weightIndex = theNetwork.getWeightIndex();
        this.layerOutput = theNetwork.getLayerOutput();
        this.layerSums = theNetwork.getLayerSums();
        this.layerFeedCounts = theNetwork.getLayerFeedCounts();

        if (this.layerDelta == null || this.layerDelta.length != this.layerOutput.length) {
            this.layerDelta = new double[this.layerOutput.length];
        }
        if (this.gradients == null || this.gradients.length != this.weights.length) {
            this.gradients = new double[this.weights.length];
        }
        // Private scratch buffers: always re-sized for the bound network.
        this.actual = new double[theNetwork.getOutputCount()];
        this.pair = BasicMLDataPair.createPair(theNetwork.getInputCount(), theNetwork.getOutputCount());
    }

    /**
     * Runs one record forward and backward, adding its contribution to the gradient buffer.
     *
     * @param input input vector of the record
     * @param ideal expected output vector of the record
     * @param s     record significance; passed to the error calculation and multiplied into the
     *              output-layer deltas
     */
    private void process(double[] input, double[] ideal, double s) {
        FlatNetwork net = this.getNetwork();
        double[] delta = this.getLayerDelta();

        // Forward pass. The loops below read layerSums/layerOutput afterwards, i.e. they rely on
        // compute() refreshing those network arrays (NOTE(review): behavior of FlatNetwork).
        net.compute(input, this.actual);
        this.errorCalculation.updateError(this.actual, ideal, s);

        // Seed the first actual.length deltas (treated as the output layer) from the error
        // function, then scale by the output activation derivative (+ flat spot) and significance.
        this.errorFunction.calculateError(ideal, this.actual, delta);
        ActivationFunction outputActivation = net.getActivationFunctions()[0];
        double outputFlatSpot = this.flatSpot[0];
        for (int i = 0; i < this.actual.length; ++i) {
            delta[i] = (outputActivation.derivativeFunction(this.layerSums[i], this.layerOutput[i]) + outputFlatSpot)
                    * (delta[i] * s);
        }

        // Backward pass over the level range the network reports as trainable.
        int end = net.getEndTraining();
        for (int level = net.getBeginTraining(); level < end; ++level) {
            this.processLevel(level);
        }
    }

    /**
     * Back-propagates deltas from level {@code currentLevel} to level {@code currentLevel + 1}.
     *
     * <p>For each neuron y of level {@code currentLevel + 1} and each neuron x of level
     * {@code currentLevel}, the weight at {@code weightIndex[currentLevel] + x * fromLayerSize + y}
     * accumulates {@code output(y) * delta(x)} into the gradient buffer. The delta of y then becomes
     * the weight-sum of the x deltas times y's activation derivative plus the level's flat spot.
     *
     * @param currentLevel level whose deltas are already known
     */
    private void processLevel(int currentLevel) {
        int fromLayerIndex = this.layerIndex[currentLevel + 1];
        int toLayerIndex = this.layerIndex[currentLevel];
        int fromLayerSize = this.layerCounts[currentLevel + 1];
        // layerFeedCounts rather than layerCounts; presumably this excludes bias neurons that
        // receive no incoming weights — TODO confirm against FlatNetwork.
        int toLayerSize = this.layerFeedCounts[currentLevel];
        int index = this.weightIndex[currentLevel];
        ActivationFunction activation = this.getNetwork().getActivationFunctions()[currentLevel + 1];
        double currentFlatSpot = this.flatSpot[currentLevel + 1];
        double[] delta = this.getLayerDelta();

        int yi = fromLayerIndex;
        for (int y = 0; y < fromLayerSize; ++y, ++yi) {
            double output = this.layerOutput[yi];
            double sum = 0.0;
            int wi = index + y;
            int xi = toLayerIndex;
            for (int x = 0; x < toLayerSize; ++x, ++xi, wi += fromLayerSize) {
                this.gradients[wi] += output * delta[xi];
                sum += this.weights[wi] * delta[xi];
            }
            delta[yi] = sum * (activation.derivativeFunction(this.layerSums[yi], this.layerOutput[yi]) + currentFlatSpot);
        }
    }

    /**
     * Recomputes the gradients over every record of the training set.
     *
     * <p>Resets the error calculation, zeroes the gradient buffer, processes each record in index
     * order and stores the resulting error for {@link #getError()}.
     *
     * @throws RuntimeException if processing fails; non-runtime throwables are wrapped as the cause
     */
    public final void run() {
        try {
            this.errorCalculation.reset();
            Arrays.fill(this.gradients, 0.0);
            // long counter: getRecordCount() is a long and may exceed Integer.MAX_VALUE.
            long recordCount = this.training.getRecordCount();
            for (long i = 0; i < recordCount; ++i) {
                this.training.getRecord(i, this.pair);
                this.process(this.pair.getInputArray(), this.pair.getIdealArray(), this.pair.getSignificance());
            }
            this.error = this.errorCalculation.calculate();
        } catch (RuntimeException ex) {
            // Already unchecked: rethrow as-is instead of double-wrapping.
            throw ex;
        } catch (Throwable ex) {
            throw new RuntimeException(ex);
        }
    }

    /** Returns the error calculation updated by {@link #run()}. */
    public ErrorCalculation getErrorCalculation() {
        return this.errorCalculation;
    }

    /** Returns the live gradient buffer (one entry per weight) filled by the last {@link #run()}. */
    public double[] getGradients() {
        return this.gradients;
    }

    /** Returns the error computed at the end of the last completed {@link #run()}. */
    public double getError() {
        return this.error;
    }

    /** Returns the weight array referenced by this worker (not a copy). */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Replaces the weights used by the backward pass and installs the same array on the network.
     *
     * @param weights new weight array; stored by reference
     */
    public void setWeights(double[] weights) {
        this.weights = weights;
        this.getNetwork().setWeights(weights);
    }

    /**
     * Rebinds this worker to the flat network of {@code network}, refreshing every cached
     * network array so the backward pass reads the same arrays the forward pass writes.
     *
     * @param network network whose flat representation becomes the differentiated network
     */
    public void setParams(BasicNetwork network) {
        this.bind(network.getFlat());
    }

    /** Returns the currently bound flat network. */
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /** Returns the live per-neuron delta buffer used by the backward pass. */
    public double[] getLayerDelta() {
        return this.layerDelta;
    }
}

