/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation;

import org.encog.EncogError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.mathutil.IntRange;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.Train;
import org.encog.neural.networks.training.propagation.GradientWorker;
import org.encog.util.EncogValidate;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.DetermineWorkload;
import org.encog.util.concurrency.EngineConcurrency;
import org.encog.util.concurrency.MultiThreadable;
import org.encog.util.concurrency.TaskGroup;
import org.encog.util.logging.EncogLogging;

/**
 * Base class for gradient-based propagation trainers operating on a {@link FlatNetwork}.
 *
 * <p>Gradients are computed by one or more {@link GradientWorker}s. Each worker is built in
 * {@link #init()} from a clone of the flat network, an additional view of the training set and a
 * sub-range of its records. Worker results are accumulated into {@link #gradients} through
 * {@link #report(double[], double, Throwable)}. Subclasses supply the per-weight update rule via
 * {@link #updateWeight(double[], double[], int)} and any extra set-up via {@link #initOthers()}.
 */
public abstract class Propagation
extends BasicTraining
implements Train,
MultiThreadable {
    private FlatNetwork currentFlatNetwork;
    private int numThreads;
    /** Gradient accumulator; filled by {@link #report} and zeroed after every weight update. */
    protected double[] gradients;
    private final double[] lastGradient;
    protected final ContainsFlat network;
    private final MLDataSet indexable;
    private GradientWorker[] workers;
    private double totalError;
    protected double lastError;
    /** First failure reported by a worker during the current gradient pass, or {@code null}. */
    private Throwable reportedException;
    private int iteration;
    private double[] flatSpot;
    private boolean shouldFixFlatSpot;
    private ErrorFunction ef = new LinearErrorFunction();

    /**
     * Creates a propagation trainer.
     *
     * @param network  the network to train; its flat network is captured once, here
     * @param training the training data
     */
    public Propagation(ContainsFlat network, MLDataSet training) {
        super(TrainingImplementationType.Iterative);
        this.network = network;
        this.currentFlatNetwork = network.getFlat();
        this.setTraining(training);
        int weightCount = this.currentFlatNetwork.getWeights().length;
        this.gradients = new double[weightCount];
        this.lastGradient = new double[weightCount];
        this.indexable = training;
        this.numThreads = 0;
        this.reportedException = null;
        this.shouldFixFlatSpot = true;
    }

    @Override
    public final void finishTraining() {
        super.finishTraining();
    }

    /** @return the flat network whose weights this trainer updates */
    public final FlatNetwork getCurrentFlatNetwork() {
        return this.currentFlatNetwork;
    }

    @Override
    public final MLMethod getMethod() {
        return this.network;
    }

    /** Performs a single training iteration; equivalent to {@code iteration(1)}. */
    @Override
    public void iteration() {
        this.iteration(1);
    }

    /** Increments this trainer's private iteration counter. */
    public void rollIteration() {
        ++this.iteration;
    }

    /**
     * Performs {@code count} training iterations.
     *
     * <p>Each iteration computes gradients and applies them to the flat network's weights. It then
     * copies the updated weights into every worker and, for networks with context, passes context
     * layer outputs along.
     *
     * @param count number of iterations to run
     * @throws EncogError if a worker reported a failure while computing gradients (in that case no
     *     weights are modified), or wrapping an {@link ArrayIndexOutOfBoundsException} after the
     *     network has been validated against the training data
     */
    @Override
    public final void iteration(int count) {
        try {
            for (int i = 0; i < count; ++i) {
                this.preIteration();
                this.rollIteration();
                this.calculateGradients();
                // A failing worker contributes no gradients (see report), so the accumulator is
                // incomplete; fail here, before those partial gradients reach the weights.
                this.throwIfWorkerFailed();
                if (this.currentFlatNetwork.isLimited()) {
                    this.learnLimited();
                } else {
                    this.learn();
                }
                this.lastError = this.getError();
                this.syncWorkerWeights();
                if (this.currentFlatNetwork.getHasContext()) {
                    this.copyContexts();
                }
                this.postIteration();
                EncogLogging.log(1, "Training iteration done, error: " + this.getError());
            }
        }
        catch (ArrayIndexOutOfBoundsException ex) {
            EncogValidate.validateNetworkForTraining(this.network, this.getTraining());
            throw new EncogError(ex);
        }
    }

    /**
     * Sets the requested worker thread count.
     *
     * <p>This is only consulted when the workers are first created, during the first gradient
     * calculation. Later changes have no effect on an already-initialised trainer.
     */
    @Override
    public final void setThreadCount(int numThreads) {
        this.numThreads = numThreads;
    }

    @Override
    public int getThreadCount() {
        return this.numThreads;
    }

    /**
     * Enables or disables the flat-spot offset (0.1 for {@link ActivationSigmoid} layers).
     *
     * <p>This is read only when the workers are first created; call it before the first iteration.
     */
    public void fixFlatSpot(boolean b) {
        this.shouldFixFlatSpot = b;
    }

    /**
     * Sets the error function handed to each worker.
     *
     * <p>This is read only when the workers are first created; call it before the first iteration.
     */
    public void setErrorFunction(ErrorFunction ef) {
        this.ef = ef;
    }

    /**
     * Runs every worker once and accumulates their gradients into {@link #gradients}.
     *
     * <p>Workers are created lazily on the first call. With more than one worker they are run as
     * tasks on the shared {@link EngineConcurrency} instance, and this method waits for all of them
     * to finish; otherwise the single worker runs on the calling thread. Afterwards the training
     * error is set to the sum of the reported errors divided by the number of workers.
     */
    public void calculateGradients() {
        if (this.workers == null) {
            this.init();
        }
        if (this.currentFlatNetwork.getHasContext()) {
            this.workers[0].getNetwork().clearContext();
        }
        this.totalError = 0.0;
        if (this.workers.length > 1) {
            EngineConcurrency engine = EngineConcurrency.getInstance();
            TaskGroup group = engine.createTaskGroup();
            for (GradientWorker worker : this.workers) {
                engine.processTask(worker, group);
            }
            group.waitForComplete();
        } else {
            this.workers[0].run();
        }
        this.setError(this.totalError / (double)this.workers.length);
    }

    /**
     * Throws if a worker reported a failure during the last gradient pass.
     *
     * <p>Before throwing, it discards the partial gradient accumulation and clears the stored
     * failure. That way a later iteration starts from a clean state instead of rethrowing a stale
     * error forever.
     */
    private void throwIfWorkerFailed() {
        Throwable failure = this.reportedException;
        if (failure == null) {
            return;
        }
        this.reportedException = null;
        EngineArray.fill(this.gradients, 0.0);
        throw new EncogError(failure);
    }

    /** Copies the flat network's current weights into every worker's weight array. */
    private void syncWorkerWeights() {
        double[] weights = this.currentFlatNetwork.getWeights();
        for (GradientWorker worker : this.workers) {
            EngineArray.arrayCopy(weights, 0, worker.getWeights(), 0, weights.length);
        }
    }

    /**
     * Passes context layer outputs along the chain of workers.
     *
     * <p>Each worker receives the layer output of the previous worker, and the flat network
     * receives the last worker's. NOTE(review): presumably this is so recurrent state flows across
     * the consecutive record ranges assigned to the workers; confirm.
     */
    private void copyContexts() {
        for (int i = 0; i < this.workers.length - 1; ++i) {
            double[] src = this.workers[i].getNetwork().getLayerOutput();
            double[] dst = this.workers[i + 1].getNetwork().getLayerOutput();
            EngineArray.arrayCopy(src, dst);
        }
        EngineArray.arrayCopy(this.workers[this.workers.length - 1].getNetwork().getLayerOutput(), this.currentFlatNetwork.getLayerOutput());
    }

    /**
     * Builds the flat-spot table and the gradient workers, then calls {@link #initOthers()}.
     *
     * <p>Each worker gets its own clone of the flat network, an additional view of the training set,
     * one record range from {@link DetermineWorkload}, the flat-spot table and the error function.
     */
    private void init() {
        ActivationFunction[] activations = this.currentFlatNetwork.getActivationFunctions();
        // Java zero-initialises new arrays, so only the sigmoid entries need an explicit value.
        this.flatSpot = new double[activations.length];
        if (this.shouldFixFlatSpot) {
            for (int i = 0; i < activations.length; ++i) {
                if (activations[i] instanceof ActivationSigmoid) {
                    this.flatSpot[i] = 0.1;
                }
            }
        }
        DetermineWorkload determine = new DetermineWorkload(this.numThreads, (int)this.indexable.getRecordCount());
        this.workers = new GradientWorker[determine.getThreadCount()];
        int index = 0;
        for (IntRange r : determine.calculateWorkers()) {
            this.workers[index++] = new GradientWorker(this.currentFlatNetwork.clone(), this, this.indexable.openAdditional(), r.getLow(), r.getHigh(), this.flatSpot, this.ef);
        }
        this.initOthers();
    }

    /**
     * Accumulates one worker's result into this trainer. The method is synchronized, so
     * concurrent workers may call it safely.
     *
     * <p>If {@code ex} is non-null, the failure is recorded and {@code gradients} and {@code error}
     * are ignored. The failure is raised by the next {@link #iteration(int)} check.
     *
     * @param gradients the worker's gradients, added element-wise to {@link #gradients}
     * @param error     the worker's error, added to the running total
     * @param ex        the worker's failure, or {@code null} on success
     */
    public final synchronized void report(double[] gradients, double error, Throwable ex) {
        if (ex != null) {
            this.reportedException = ex;
            return;
        }
        for (int i = 0; i < gradients.length; ++i) {
            this.gradients[i] += gradients[i];
        }
        this.totalError += error;
    }

    /**
     * Adds {@link #updateWeight} to every weight, then zeroes the accumulated gradients.
     */
    protected void learn() {
        double[] weights = this.currentFlatNetwork.getWeights();
        for (int i = 0; i < this.gradients.length; ++i) {
            weights[i] += this.updateWeight(this.gradients, this.lastGradient, i);
            this.gradients[i] = 0.0;
        }
    }

    /**
     * Like {@link #learn()}, but for connection-limited networks.
     *
     * <p>Any weight whose magnitude is below the network's connection limit is forced to zero
     * instead of being updated. Accumulated gradients are zeroed either way.
     */
    protected void learnLimited() {
        double limit = this.currentFlatNetwork.getConnectionLimit();
        double[] weights = this.currentFlatNetwork.getWeights();
        for (int i = 0; i < this.gradients.length; ++i) {
            if (Math.abs(weights[i]) < limit) {
                weights[i] = 0.0;
            } else {
                weights[i] += this.updateWeight(this.gradients, this.lastGradient, i);
            }
            this.gradients[i] = 0.0;
        }
    }

    /** Hook for subclass set-up. It is called once, after the gradient workers have been created. */
    public abstract void initOthers();

    /**
     * Computes the change for a single weight.
     *
     * @param gradients    accumulated gradients for the current iteration
     * @param lastGradient the live array returned by {@link #getLastGradient()}
     * @param index        index of the weight being updated
     * @return the delta that is added to weight {@code index}
     */
    public abstract double updateWeight(double[] gradients, double[] lastGradient, int index);

    /**
     * @return the live (not copied) array passed to {@link #updateWeight} as {@code lastGradient}
     */
    public double[] getLastGradient() {
        return this.lastGradient;
    }
}

