/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.network.train.prop;

import org.encog.engine.EncogEngineError;
import org.encog.engine.concurrency.DetermineWorkload;
import org.encog.engine.concurrency.EngineConcurrency;
import org.encog.engine.concurrency.TaskGroup;
import org.encog.engine.data.EngineDataSet;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.TrainFlatNetwork;
import org.encog.engine.network.train.gradient.FlatGradientWorker;
import org.encog.engine.network.train.gradient.GradientWorkerCPU;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.IntRange;

/**
 * Base class for gradient-based trainers that operate on a {@link FlatNetwork}.
 *
 * <p>Each iteration distributes the training set across one or more
 * {@link FlatGradientWorker}s, sums the gradients and errors they report, and
 * then applies a subclass-defined weight change via
 * {@link #updateWeight(double[], double[], int)}.
 *
 * <p>The training data must implement {@link EngineIndexableSet}; the
 * constructor rejects anything else.
 */
public abstract class TrainFlatNetworkProp implements TrainFlatNetwork {
    /**
     * Thread count handed to {@link DetermineWorkload} when the workers are
     * created. Defaults to 0.
     */
    private int numThreads;

    /** Gradients accumulated from all workers for the current iteration. */
    protected double[] gradients;

    /**
     * Scratch array, one slot per weight, passed to {@link #updateWeight} so
     * the subclass can keep per-weight state between iterations.
     */
    private double[] lastGradient;

    /** The network being trained. */
    protected final FlatNetwork network;

    /** The training data, as supplied to the constructor. */
    private final EngineDataSet training;

    /** The same training data, viewed through its indexable interface. */
    private final EngineIndexableSet indexable;

    /** Gradient workers; created lazily on the first gradient calculation. */
    private FlatGradientWorker[] workers;

    /** Sum of the errors reported by the workers during the current pass. */
    private double totalError;

    /** Error of the most recent gradient pass (mean of the worker errors). */
    protected double currentError;

    /** Failure reported by a worker during the current pass, or null. */
    private Throwable reportedException;

    /** Number of iterations performed so far. */
    private int iteration;

    /**
     * Creates a trainer for the given network and training data.
     *
     * @param network  the network to train
     * @param training the training data; must also be an
     *                 {@link EngineIndexableSet}
     * @throws EncogEngineError if {@code training} is not indexable
     */
    public TrainFlatNetworkProp(FlatNetwork network, EngineDataSet training) {
        if (!(training instanceof EngineIndexableSet)) {
            throw new EncogEngineError("Training data must be Indexable for this training type.");
        }
        this.training = training;
        this.network = network;
        this.gradients = new double[this.network.getWeights().length];
        this.lastGradient = new double[this.network.getWeights().length];
        this.indexable = (EngineIndexableSet) training;
        this.numThreads = 0;
        this.reportedException = null;
    }

    /**
     * Runs every gradient worker once and stores the mean of their reported
     * errors in {@link #currentError}.
     *
     * <p>Workers are created on first use. With more than one worker the work
     * is submitted to {@link EngineConcurrency} and this method blocks until
     * the task group completes; a single worker is run on the calling thread.
     */
    public void calculateGradients() {
        if (this.workers == null) {
            this.init();
        }
        this.workers[0].getNetwork().clearContext();
        this.totalError = 0.0;
        if (this.workers.length > 1) {
            TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
            for (FlatGradientWorker worker : this.workers) {
                EngineConcurrency.getInstance().processTask(worker, group);
            }
            group.waitForComplete();
        } else {
            this.workers[0].run();
        }
        this.currentError = this.totalError / (double) this.workers.length;
    }

    /**
     * Copies the layer-output array of each worker's network into the next
     * worker's network, and the last worker's layer output into the main
     * network.
     */
    private void copyContexts() {
        for (int i = 0; i < this.workers.length - 1; ++i) {
            double[] src = this.workers[i].getNetwork().getLayerOutput();
            double[] dst = this.workers[i + 1].getNetwork().getLayerOutput();
            EngineArray.arrayCopy(src, dst);
        }
        EngineArray.arrayCopy(
                this.workers[this.workers.length - 1].getNetwork().getLayerOutput(),
                this.network.getLayerOutput());
    }

    /** No cleanup is required by this trainer; intentionally empty. */
    @Override
    public void finishTraining() {
    }

    /**
     * @return the error computed by the most recent gradient pass
     */
    @Override
    public double getError() {
        return this.currentError;
    }

    /**
     * @return the internal last-gradient array (not a copy)
     */
    public double[] getLastGradient() {
        return this.lastGradient;
    }

    /**
     * @return the network being trained
     */
    @Override
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /**
     * @return the configured thread count
     */
    @Override
    public int getNumThreads() {
        return this.numThreads;
    }

    /**
     * @return the training data supplied to the constructor
     */
    @Override
    public EngineDataSet getTraining() {
        return this.training;
    }

    /**
     * Creates the gradient workers. {@link DetermineWorkload} decides how many
     * workers to use and which record range each one covers; every worker gets
     * its own clone of the network and its own additional view of the data.
     */
    private void init() {
        // The record count is narrowed to int here, as DetermineWorkload is given an int.
        DetermineWorkload determine =
                new DetermineWorkload(this.numThreads, (int) this.indexable.getRecordCount());
        this.workers = new FlatGradientWorker[determine.getThreadCount()];
        int index = 0;
        for (IntRange r : determine.calculateWorkers()) {
            this.workers[index++] = new GradientWorkerCPU(
                    this.network.clone(), this, this.indexable.openAdditional(),
                    r.getLow(), r.getHigh());
        }
    }

    /**
     * Performs one training iteration: computes gradients, applies the weight
     * update, pushes the new weights to every worker and propagates the
     * layer-output state between the networks.
     *
     * @throws EncogEngineError if a failure was reported during the gradient
     *                          pass; in that case the network weights are left
     *                          untouched
     */
    @Override
    public void iteration() {
        ++this.iteration;
        this.calculateGradients();

        // Fail before touching the weights: the accumulated gradients are
        // incomplete when a worker reported a failure. Clear the stored failure
        // and the partial gradients so a later iteration starts clean instead
        // of rethrowing a stale exception or reusing half-summed gradients.
        if (this.reportedException != null) {
            Throwable failure = this.reportedException;
            this.reportedException = null;
            java.util.Arrays.fill(this.gradients, 0.0);
            throw new EncogEngineError(failure);
        }

        if (this.network.isLimited()) {
            this.learnLimited();
        } else {
            this.learn();
        }

        double[] updated = this.network.getWeights();
        for (FlatGradientWorker worker : this.workers) {
            EngineArray.arrayCopy(updated, 0, worker.getWeights(), 0, updated.length);
        }
        this.copyContexts();
    }

    /**
     * Applies the subclass-defined update to every weight and resets the
     * accumulated gradients to zero for the next iteration.
     */
    protected void learn() {
        double[] weights = this.network.getWeights();
        for (int i = 0; i < this.gradients.length; ++i) {
            weights[i] += this.updateWeight(this.gradients, this.lastGradient, i);
            this.gradients[i] = 0.0;
        }
    }

    /**
     * Variant of {@link #learn()} used when the network reports itself as
     * limited. A weight whose magnitude is below the network's connection
     * limit is forced to zero and is not updated; all other weights are
     * updated normally. Accumulated gradients are reset to zero.
     */
    protected void learnLimited() {
        double limit = this.network.getConnectionLimit();
        double[] weights = this.network.getWeights();
        for (int i = 0; i < this.gradients.length; ++i) {
            // Compare magnitudes: a signed comparison would wipe out every
            // negative weight regardless of how strong the connection is.
            if (Math.abs(weights[i]) < limit) {
                weights[i] = 0.0;
            } else {
                weights[i] += this.updateWeight(this.gradients, this.lastGradient, i);
            }
            this.gradients[i] = 0.0;
        }
    }

    /**
     * Merges one worker's result into this trainer's totals. Synchronized on
     * this trainer so that concurrent reports do not interleave.
     *
     * <p>NOTE(review): presumably called by the gradient workers, which are
     * handed this trainer at construction; the call site is not in this file.
     *
     * @param gradients the gradients to add to the accumulated gradients;
     *                  ignored when {@code ex} is non-null
     * @param error     the error to add to the running total; ignored when
     *                  {@code ex} is non-null
     * @param ex        a failure to record instead of a result, or
     *                  {@code null} on success
     */
    public void report(double[] gradients, double error, Throwable ex) {
        synchronized (this) {
            if (ex == null) {
                for (int i = 0; i < gradients.length; ++i) {
                    this.gradients[i] += gradients[i];
                }
                this.totalError += error;
            } else {
                this.reportedException = ex;
            }
        }
    }

    /**
     * Sets the thread count used when the workers are created. The workers are
     * created once, on the first gradient calculation, so a change made after
     * that point has no effect on the existing workers.
     *
     * @param numThreads the thread count to pass to {@link DetermineWorkload}
     */
    @Override
    public void setNumThreads(int numThreads) {
        this.numThreads = numThreads;
    }

    /**
     * Computes the amount to add to a single weight.
     *
     * @param gradients    the accumulated gradients for this iteration
     * @param lastGradient per-weight state array owned by this trainer, which
     *                     the implementation may read and update
     * @param index        the index of the weight being updated
     * @return the delta to add to the weight at {@code index}
     */
    public abstract double updateWeight(double[] gradients, double[] lastGradient, int index);

    /**
     * Performs the given number of iterations back to back.
     *
     * @param count how many iterations to run; nothing happens if it is not
     *              positive
     */
    @Override
    public void iteration(int count) {
        for (int i = 0; i < count; ++i) {
            this.iteration();
        }
    }

    /**
     * @return the number of iterations performed so far
     */
    @Override
    public int getIteration() {
        return this.iteration;
    }

    /**
     * Overrides the iteration counter.
     *
     * @param iteration the new iteration count
     */
    @Override
    public void setIteration(int iteration) {
        this.iteration = iteration;
    }
}

