/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.flat.train.prop;

import org.encog.EncogError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.mathutil.IntRange;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.flat.train.TrainFlatNetwork;
import org.encog.neural.flat.train.prop.GradientWorker;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.DetermineWorkload;
import org.encog.util.concurrency.EngineConcurrency;
import org.encog.util.concurrency.TaskGroup;

/**
 * Base class for propagation-style trainers operating on a {@link FlatNetwork}.
 *
 * <p>Each call to {@link #iteration()} computes gradients over the training set
 * (split across one or more {@link GradientWorker}s), asks the concrete subclass
 * for a per-weight delta via {@link #updateWeight(double[], double[], int)}, applies
 * it, and then re-synchronises every worker's weight copy with the master network.
 *
 * <p>Workers are created lazily on the first gradient calculation. Settings that
 * are captured by the workers at that point (thread count, flat-spot fix, error
 * function) therefore only take effect if changed before the first iteration.
 */
public abstract class TrainFlatNetworkProp
implements TrainFlatNetwork {
    /** Requested thread count; passed to DetermineWorkload (0 presumably means "auto" — TODO confirm). */
    private int numThreads;
    /** Gradient accumulator, summed across workers by {@link #report}; zeroed after each weight update. */
    protected double[] gradients;
    /** Previous-iteration gradients, handed to {@link #updateWeight} for subclasses that need history. */
    private final double[] lastGradient;
    /** The master network whose weights are being trained. */
    protected final FlatNetwork network;
    /** The training data as supplied by the caller. */
    private final MLDataSet training;
    /** Data set from which per-worker views are opened. */
    private final MLDataSet indexable;
    /** Gradient workers; null until {@link #init()} runs. */
    private GradientWorker[] workers;
    /** Sum of the error values reported by workers during the current pass. */
    private double totalError;
    /** Error computed by the most recent gradient pass. */
    protected double currentError;
    /** Error of the last completed iteration. */
    protected double lastError;
    /** First failure reported by a worker during the current pass, or null. */
    private Throwable reportedException;
    /** Number of iterations performed so far. */
    private int iteration;
    /** Per-layer flat-spot constants handed to each worker. */
    private double[] flatSpot;
    /** Whether sigmoid layers get a non-zero flat-spot constant. */
    private boolean shouldFixFlatSpot;
    /** Error function handed to each worker. */
    private ErrorFunction ef = new LinearErrorFunction();

    /**
     * Creates a trainer for the given network and training data.
     *
     * @param network  the network to train
     * @param training the training data
     */
    public TrainFlatNetworkProp(FlatNetwork network, MLDataSet training) {
        this.training = training;
        this.network = network;
        this.gradients = new double[this.network.getWeights().length];
        this.lastGradient = new double[this.network.getWeights().length];
        this.indexable = training;
        this.numThreads = 0;
        this.reportedException = null;
        this.shouldFixFlatSpot = true;
    }

    /**
     * Runs one gradient pass over the whole training set.
     *
     * <p>Gradients are accumulated into {@link #gradients} through
     * {@link #report}, and {@link #currentError} is set to the mean of the error
     * values the workers reported. Any failure from a previous pass is cleared
     * first, so each pass reports only its own failures.
     */
    public void calculateGradients() {
        if (this.workers == null) {
            this.init();
        }
        if (this.network.getHasContext()) {
            this.workers[0].getNetwork().clearContext();
        }
        // Per-pass accumulators: reset both the error sum and any stale failure.
        this.totalError = 0.0;
        this.reportedException = null;
        if (this.workers.length > 1) {
            TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
            for (GradientWorker worker : this.workers) {
                EngineConcurrency.getInstance().processTask(worker, group);
            }
            group.waitForComplete();
        } else {
            this.workers[0].run();
        }
        // Averages per-worker errors; assumes each worker reports an error already
        // normalised over its own slice — TODO confirm against GradientWorker.
        this.currentError = this.totalError / (double) this.workers.length;
    }

    /**
     * Chains recurrent context between workers.
     *
     * <p>Each worker's layer output is copied into the next worker's network, and
     * the last worker's output is copied into the master network.
     */
    private void copyContexts() {
        for (int i = 0; i < this.workers.length - 1; ++i) {
            double[] src = this.workers[i].getNetwork().getLayerOutput();
            double[] dst = this.workers[i + 1].getNetwork().getLayerOutput();
            EngineArray.arrayCopy(src, dst);
        }
        EngineArray.arrayCopy(this.workers[this.workers.length - 1].getNetwork().getLayerOutput(), this.network.getLayerOutput());
    }

    /** No clean-up is needed by default; subclasses may override. */
    @Override
    public void finishTraining() {
    }

    /** @return the error computed by the most recent gradient pass */
    @Override
    public final double getError() {
        return this.currentError;
    }

    /** @return the number of iterations performed so far */
    @Override
    public final int getIteration() {
        return this.iteration;
    }

    /** @return the live (not copied) previous-gradient array passed to {@link #updateWeight} */
    public final double[] getLastGradient() {
        return this.lastGradient;
    }

    /** @return the network being trained */
    @Override
    public final FlatNetwork getNetwork() {
        return this.network;
    }

    /** @return the requested thread count */
    @Override
    public final int getNumThreads() {
        return this.numThreads;
    }

    /** @return the training data */
    @Override
    public final MLDataSet getTraining() {
        return this.training;
    }

    /**
     * Enables or disables the flat-spot fix for sigmoid layers.
     *
     * <p>Only takes effect if called before the first iteration, because the
     * flat-spot constants are computed when the workers are created.
     *
     * @param e true to apply the fix
     */
    public void fixFlatSpot(boolean e) {
        this.shouldFixFlatSpot = e;
    }

    /** Computes the flat-spot table, creates the gradient workers, then calls {@link #initOthers()}. */
    private void init() {
        ActivationFunction[] activations = this.network.getActivationFunctions();
        // A new double[] is zero-filled, which is the "no fix" value for every layer.
        this.flatSpot = new double[activations.length];
        if (this.shouldFixFlatSpot) {
            for (int i = 0; i < activations.length; ++i) {
                // Presumably added to the sigmoid derivative by the worker so that
                // learning does not stall in saturated regions — TODO confirm.
                this.flatSpot[i] = activations[i] instanceof ActivationSigmoid ? 0.1 : 0.0;
            }
        }
        DetermineWorkload determine = new DetermineWorkload(this.numThreads, (int) this.indexable.getRecordCount());
        this.workers = new GradientWorker[determine.getThreadCount()];
        int index = 0;
        for (IntRange r : determine.calculateWorkers()) {
            this.workers[index++] = new GradientWorker(this.network.clone(), this, this.indexable.openAdditional(), r.getLow(), r.getHigh(), this.flatSpot, this.ef);
        }
        this.initOthers();
    }

    /**
     * Performs one training iteration.
     *
     * <p>Steps: compute gradients, update the weights, copy the new weights to
     * every worker, and propagate recurrent context if the network has any.
     *
     * @throws EncogError if a gradient worker failed during this pass. In that
     *                    case the weights are left untouched and the incomplete
     *                    gradients are discarded.
     */
    @Override
    public void iteration() {
        ++this.iteration;
        this.calculateGradients();
        if (this.reportedException != null) {
            // The failed worker contributed no gradients, so the accumulated sum is
            // incomplete. Applying it would corrupt the weights; discard it instead.
            EngineArray.fill(this.gradients, 0.0);
            throw new EncogError(this.reportedException);
        }
        if (this.network.isLimited()) {
            this.learnLimited();
        } else {
            this.learn();
        }
        this.lastError = this.currentError;
        double[] masterWeights = this.network.getWeights();
        for (GradientWorker worker : this.workers) {
            EngineArray.arrayCopy(masterWeights, 0, worker.getWeights(), 0, masterWeights.length);
        }
        if (this.network.getHasContext()) {
            this.copyContexts();
        }
    }

    /**
     * Performs {@code count} iterations.
     *
     * @param count number of iterations to run; a value of zero or less does nothing
     */
    @Override
    public final void iteration(int count) {
        for (int i = 0; i < count; ++i) {
            this.iteration();
        }
    }

    /** Applies {@link #updateWeight} to every weight, then zeroes the gradient accumulator. */
    protected void learn() {
        double[] weights = this.network.getWeights();
        for (int i = 0; i < this.gradients.length; ++i) {
            weights[i] += this.updateWeight(this.gradients, this.lastGradient, i);
            this.gradients[i] = 0.0;
        }
    }

    /**
     * Like {@link #learn()}, but honours the connection limit.
     *
     * <p>Weights whose magnitude is below the limit are set to zero instead of
     * being updated.
     */
    protected void learnLimited() {
        double limit = this.network.getConnectionLimit();
        double[] weights = this.network.getWeights();
        for (int i = 0; i < this.gradients.length; ++i) {
            if (Math.abs(weights[i]) < limit) {
                weights[i] = 0.0;
            } else {
                weights[i] += this.updateWeight(this.gradients, this.lastGradient, i);
            }
            this.gradients[i] = 0.0;
        }
    }

    /**
     * Called by workers to contribute their results; synchronised on this trainer.
     *
     * @param gradients the worker's gradients, added element-wise to the
     *                  accumulator (ignored when {@code ex} is non-null)
     * @param error     the worker's error, added to the running total (ignored
     *                  when {@code ex} is non-null)
     * @param ex        a failure raised by the worker, or null on success
     */
    public final void report(double[] gradients, double error, Throwable ex) {
        synchronized (this) {
            if (ex == null) {
                for (int i = 0; i < gradients.length; ++i) {
                    this.gradients[i] += gradients[i];
                }
                this.totalError += error;
            } else {
                this.reportedException = ex;
            }
        }
    }

    /** @param iteration the iteration counter value to set */
    @Override
    public void setIteration(int iteration) {
        this.iteration = iteration;
    }

    /**
     * Sets the requested thread count.
     *
     * <p>Only takes effect if called before the first iteration, because the
     * workers are created once.
     *
     * @param numThreads the thread count passed to DetermineWorkload
     */
    @Override
    public void setNumThreads(int numThreads) {
        this.numThreads = numThreads;
    }

    /**
     * Computes the delta to add to one weight.
     *
     * @param gradients    the accumulated gradients for the current pass
     * @param lastGradient the previous-gradient array (subclasses may update it)
     * @param index        the weight index
     * @return the amount to add to weight {@code index}
     */
    public abstract double updateWeight(double[] gradients, double[] lastGradient, int index);

    /**
     * Sets the error function handed to the workers.
     *
     * <p>Only takes effect if called before the first iteration, because the
     * workers capture it when they are created.
     *
     * @param ef the error function
     */
    public void setErrorFunction(ErrorFunction ef) {
        this.ef = ef;
    }

    /** @return the error function handed to the workers */
    public ErrorFunction getErrorFunction() {
        return this.ef;
    }

    /** Hook invoked once at the end of {@link #init()}, after the workers have been created. */
    public abstract void initOthers();
}

