/*
 * Decompiled with CFR 0.152.
 */
package org.encog.mathutil.matrices.hessian;

import org.encog.mathutil.IntRange;
import org.encog.mathutil.matrices.Matrix;
import org.encog.mathutil.matrices.hessian.BasicHessian;
import org.encog.mathutil.matrices.hessian.ChainRuleWorker;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.DetermineWorkload;
import org.encog.util.concurrency.EngineConcurrency;
import org.encog.util.concurrency.MultiThreadable;
import org.encog.util.concurrency.TaskGroup;

/**
 * Computes the Hessian matrix (and accumulated gradients/SSE) of a
 * {@link BasicNetwork} using the chain rule, splitting the training records
 * across one or more {@link ChainRuleWorker} instances.
 * <p>
 * When more than one worker exists the workers are dispatched through
 * {@link EngineConcurrency} and joined with a {@link TaskGroup}; with a single
 * worker it is run directly on the calling thread.
 * <p>
 * {@link #setThreadCount(int)} only takes effect if called before
 * {@link #init(BasicNetwork, MLDataSet)}, because the thread count is read
 * solely when the workers are created there.
 */
public class HessianCR
extends BasicHessian
implements MultiThreadable {
    /** Requested thread count, handed to {@link DetermineWorkload} in {@link #init}. */
    private int numThreads;
    /** One worker per record range; created in {@link #init}. */
    private ChainRuleWorker[] workers;

    /**
     * Prepares this object for computing the Hessian of the given network.
     * <p>
     * Allocates a square {@code weightCount x weightCount} Hessian matrix and
     * creates one {@link ChainRuleWorker} per record range chosen by
     * {@link DetermineWorkload}. Each worker receives its own clone of the flat
     * network and its own additional view of the training data.
     *
     * @param theNetwork  the network whose Hessian is computed
     * @param theTraining the training data used for the computation
     */
    @Override
    public void init(BasicNetwork theNetwork, MLDataSet theTraining) {
        super.init(theNetwork, theTraining);
        final int weightCount = theNetwork.getStructure().getFlat().getWeights().length;
        this.training = theTraining;
        this.network = theNetwork;
        this.hessianMatrix = new Matrix(weightCount, weightCount);
        this.hessian = this.hessianMatrix.getData();

        final DetermineWorkload determine =
                new DetermineWorkload(this.numThreads, (int) this.training.getRecordCount());
        this.workers = new ChainRuleWorker[determine.getThreadCount()];
        int index = 0;
        for (final IntRange range : determine.calculateWorkers()) {
            // Each worker gets a private network clone so workers can run concurrently.
            this.workers[index++] = new ChainRuleWorker(
                    this.flat.clone(), this.training.openAdditional(), range.getLow(), range.getHigh());
        }
    }

    /**
     * Computes the Hessian, gradients and SSE over all training records.
     * <p>
     * Starts from {@link #clear()}, then makes one pass per output neuron. After
     * each pass it adds every worker's error, gradients and Hessian contribution
     * into this object's totals. The stored SSE is half the summed error.
     */
    @Override
    public void compute() {
        this.clear();
        double e = 0.0;
        final int weightCount = this.network.getFlat().getWeights().length;
        for (int outputNeuron = 0; outputNeuron < this.network.getOutputCount(); ++outputNeuron) {
            if (this.flat.getHasContext()) {
                clearWorkerContexts();
            }
            runWorkers(outputNeuron);
            e = aggregateWorkers(weightCount, e);
        }
        this.sse = e / 2.0;
    }

    /**
     * Resets the context of every worker's private network clone.
     * <p>
     * Previously only {@code workers[0]} was reset. With more than one worker,
     * the other clones then carried context from the previous output-neuron
     * pass into the next one.
     */
    private void clearWorkerContexts() {
        for (final ChainRuleWorker worker : this.workers) {
            worker.getNetwork().clearContext();
        }
    }

    /**
     * Runs every worker for one output neuron and waits for all of them to finish.
     *
     * @param outputNeuron index of the output neuron being processed
     */
    private void runWorkers(final int outputNeuron) {
        if (this.workers.length > 1) {
            final TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
            for (final ChainRuleWorker worker : this.workers) {
                worker.setOutputNeuron(outputNeuron);
                EngineConcurrency.getInstance().processTask(worker, group);
            }
            group.waitForComplete();
        } else {
            // Single worker: avoid thread-pool overhead and run inline.
            this.workers[0].setOutputNeuron(outputNeuron);
            this.workers[0].run();
        }
    }

    /**
     * Adds each worker's gradients and Hessian into this object's totals.
     * <p>
     * The running error is threaded through so it is summed in the same order
     * as a single flat loop would sum it.
     *
     * @param weightCount number of network weights (length of the gradient vector)
     * @param runningError error accumulated so far
     * @return {@code runningError} plus every worker's error for this pass
     */
    private double aggregateWorkers(final int weightCount, final double runningError) {
        double e = runningError;
        for (final ChainRuleWorker worker : this.workers) {
            e += worker.getError();
            final double[] workerGradients = worker.getGradients();
            for (int i = 0; i < weightCount; ++i) {
                this.gradients[i] += workerGradients[i];
            }
            EngineArray.arrayAdd(this.getHessian(), worker.getHessian());
        }
        return e;
    }

    /**
     * Sets the requested number of threads. Must be called before
     * {@link #init(BasicNetwork, MLDataSet)} to have any effect.
     *
     * @param numThreads requested thread count, passed to {@link DetermineWorkload}
     */
    @Override
    public final void setThreadCount(int numThreads) {
        this.numThreads = numThreads;
    }

    /**
     * @return the requested thread count as last set via {@link #setThreadCount(int)}
     */
    @Override
    public int getThreadCount() {
        return this.numThreads;
    }
}

