/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.network.train.gradient;

import org.encog.engine.data.BasicEngineData;
import org.encog.engine.data.EngineData;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.gradient.FlatGradientWorker;
import org.encog.engine.network.train.prop.TrainFlatNetworkProp;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ErrorCalculation;
import org.encog.engine.util.Stopwatch;

/**
 * CPU implementation of {@link FlatGradientWorker} that accumulates backpropagation
 * gradients for one contiguous slice of a training set.
 *
 * <p>Each call to {@link #run()} processes the records with indices {@code low} through
 * {@code high} (both inclusive). It then passes the summed gradients and the slice's error to
 * the owning {@link TrainFlatNetworkProp} via {@code report}. Any {@link Throwable} raised while
 * doing so is passed to {@code report} instead of propagating out of {@code run()}.
 *
 * <p>The scratch state ({@code actual}, {@code layerDelta}, {@code gradients}, {@code pair},
 * {@code errorCalculation}) is mutated without synchronization, so one instance must not execute
 * {@code run()} concurrently with itself. NOTE(review): whether several workers may safely share
 * a single {@code FlatNetwork} depends on {@code FlatNetwork.compute} writing per-call state;
 * confirm before running workers over the same network in parallel.
 */
public class GradientWorkerCPU
implements FlatGradientWorker {
    /** Network being trained; {@code compute} is called on it for every record. */
    private final FlatNetwork network;
    /** Accumulates the error between actual and ideal outputs over one {@code run()}. */
    private final ErrorCalculation errorCalculation = new ErrorCalculation();
    /** Output buffer filled by {@code network.compute}; one slot per network output. */
    private final double[] actual;
    /** Per-neuron error deltas, laid out like {@code layerOutput}. */
    private final double[] layerDelta;
    /** Neuron count per layer, as reported by the network. */
    private final int[] layerCounts;
    /** Per-layer count of neurons that receive weighted input (used as the "to" size). */
    private final int[] layerFeedCounts;
    /** Start offset of each layer inside {@code layerOutput} / {@code layerDelta}. */
    private final int[] layerIndex;
    /** Start offset of each level's weights inside {@code weights} / {@code gradients}. */
    private final int[] weightIndex;
    /**
     * Array obtained from {@code network.getLayerOutput()}. NOTE(review): processLevel reads it
     * after {@code compute}, so this assumes the getter returns the live array that compute fills.
     */
    private final double[] layerOutput;
    /** Gradient accumulator, one slot per weight; zeroed at the end of every run. */
    private final double[] gradients;
    /** Array obtained from {@code network.getWeights()} at construction time. */
    private final double[] weights;
    /** Reusable holder that each training record is copied into. */
    private final EngineData pair;
    /** Source of training records. */
    private final EngineIndexableSet training;
    /** Index of the first record processed (inclusive). */
    private final int low;
    /** Index of the last record processed (inclusive). */
    private final int high;
    /** Trainer that receives gradients, error, or a failure from each run. */
    private final TrainFlatNetworkProp owner;
    /** Stopwatch ticks taken by the most recent successful run. */
    private long elapsedTime;
    /** Times each run. */
    private final Stopwatch stopwatch;

    /**
     * Creates a worker for the inclusive record range {@code [low, high]}.
     *
     * @param network  the network whose gradients are computed
     * @param owner    the trainer that receives results via {@code report}
     * @param training the training set to read records from
     * @param low      index of the first record to process (inclusive)
     * @param high     index of the last record to process (inclusive)
     */
    public GradientWorkerCPU(FlatNetwork network, TrainFlatNetworkProp owner, EngineIndexableSet training, int low, int high) {
        this.network = network;
        this.training = training;
        this.low = low;
        this.high = high;
        this.owner = owner;
        this.stopwatch = new Stopwatch();
        this.layerDelta = new double[network.getLayerOutput().length];
        this.gradients = new double[network.getWeights().length];
        this.actual = new double[network.getOutputCount()];
        this.weights = network.getWeights();
        this.layerIndex = network.getLayerIndex();
        this.layerCounts = network.getLayerCounts();
        this.weightIndex = network.getWeightIndex();
        this.layerOutput = network.getLayerOutput();
        this.layerFeedCounts = network.getLayerFeedCounts();
        this.pair = BasicEngineData.createPair(network.getInputCount(), network.getOutputCount());
    }

    /**
     * Returns the stopwatch ticks measured by the most recent run that completed without
     * throwing. Failed runs leave the previous value in place.
     *
     * @return elapsed ticks of the last successful run (units as defined by {@code Stopwatch})
     */
    @Override
    public long getElapsedTime() {
        return this.elapsedTime;
    }

    /**
     * @return the network this worker computes gradients for
     */
    @Override
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /**
     * @return the array obtained from {@code network.getWeights()} when this worker was built
     */
    @Override
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Runs one record forward through the network, records its error, seeds the output-layer
     * deltas, and back-propagates them through the trainable levels.
     *
     * @param input the record's input values
     * @param ideal the record's expected output values
     */
    private void process(double[] input, double[] ideal) {
        this.network.compute(input, this.actual);
        this.errorCalculation.updateError(this.actual, ideal);

        // The output deltas occupy the first actual.length slots of layerDelta. Activation
        // function index 0 is the one applied to the output layer here. The lookup is
        // loop-invariant, so it is hoisted out of the loop.
        ActivationFunction outputActivation = this.network.getActivationFunctions()[0];
        for (int i = 0; i < this.actual.length; ++i) {
            this.layerDelta[i] = outputActivation.derivativeFunction(this.actual[i]) * (ideal[i] - this.actual[i]);
        }

        for (int level = this.network.getBeginTraining(); level < this.network.getEndTraining(); ++level) {
            this.processLevel(level);
        }
    }

    /**
     * Back-propagates through one level of weights.
     *
     * <p>The level connects the "from" layer ({@code currentLevel + 1}) to the "to" layer
     * ({@code currentLevel}), whose deltas must already be in {@code layerDelta}. The weight
     * between from-neuron {@code y} and to-neuron {@code x} sits at
     * {@code weightIndex[currentLevel] + x * fromLayerSize + y}.
     *
     * <p>For each from-neuron this method does two things:
     * <ul>
     *   <li>adds {@code output * delta} into that weight's gradient, and</li>
     *   <li>sets the neuron's own delta to {@code sum(weight * delta)} multiplied by its
     *       activation derivative.</li>
     * </ul>
     *
     * @param currentLevel index of the "to" layer of this level
     */
    private void processLevel(int currentLevel) {
        int fromLayerIndex = this.layerIndex[currentLevel + 1];
        int toLayerIndex = this.layerIndex[currentLevel];
        int fromLayerSize = this.layerCounts[currentLevel + 1];
        // Feed count, not total count: presumably excludes neurons (e.g. bias) that receive
        // no weighted input. Confirm against FlatNetwork.
        int toLayerSize = this.layerFeedCounts[currentLevel];
        int index = this.weightIndex[currentLevel];
        ActivationFunction activation = this.network.getActivationFunctions()[currentLevel + 1];

        int yi = fromLayerIndex;
        for (int y = 0; y < fromLayerSize; ++y) {
            double output = this.layerOutput[yi];
            double sum = 0.0;
            int xi = toLayerIndex;
            int wi = index + y;
            for (int x = 0; x < toLayerSize; ++x) {
                this.gradients[wi] += output * this.layerDelta[xi];
                sum += this.weights[wi] * this.layerDelta[xi];
                // Weights feeding consecutive to-neurons are fromLayerSize apart.
                wi += fromLayerSize;
                ++xi;
            }
            this.layerDelta[yi] = sum * activation.derivativeFunction(this.layerOutput[yi]);
            ++yi;
        }
    }

    /**
     * Accumulates gradients over records {@code low..high} (inclusive) and reports them with the
     * slice's error to the owner.
     *
     * <p>On any failure, {@code owner.report(null, 0.0, ex)} is called instead. The gradient
     * accumulator is always cleared afterwards. Previously it was cleared only on success, so a
     * failed run left partial gradients behind that were added into the next run's result.
     */
    @Override
    public void run() {
        try {
            this.stopwatch.reset();
            this.stopwatch.start();
            this.errorCalculation.reset();
            for (int i = this.low; i <= this.high; ++i) {
                this.training.getRecord(i, this.pair);
                this.process(this.pair.getInputArray(), this.pair.getIdealArray());
            }
            double error = this.errorCalculation.calculate();
            this.owner.report(this.gradients, error, null);
            this.stopwatch.stop();
            this.elapsedTime = this.stopwatch.getElapsedTicks();
        }
        catch (Throwable ex) {
            // Worker boundary: hand the failure to the owner rather than killing the thread.
            this.owner.report(null, 0.0, ex);
        }
        finally {
            // Runs after report() on both paths, so a partial accumulation never
            // survives into the next run.
            EngineArray.fill(this.gradients, 0.0);
        }
    }
}

