/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.network.train.prop;

import java.util.HashMap;
import java.util.Map;
import org.encog.engine.EncogEngine;
import org.encog.engine.EncogEngineError;
import org.encog.engine.data.EngineDataSet;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.flat.ValidateForOpenCL;
import org.encog.engine.network.train.TrainFlatNetwork;
import org.encog.engine.network.train.prop.OpenCLTrainingProfile;
import org.encog.engine.opencl.kernels.KernelNetworkTrain;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ErrorCalculation;
import org.encog.engine.util.ErrorCalculationMode;

/**
 * Trains a {@link FlatNetwork} on an OpenCL device through a
 * {@link KernelNetworkTrain} kernel.
 *
 * <p>Usage: construct, then call exactly one of the {@code learnXXXX} methods
 * ({@link #learnRPROP()}, {@link #learnRPROP(double, double)},
 * {@link #learnBPROP(double, double)} or {@link #learnManhattan(double)}) to
 * select the algorithm and compile the kernel, then call {@link #iteration()}
 * as often as needed, and finally {@link #finishTraining()} to release the
 * kernel. Calling a {@code learnXXXX} method again releases the previous kernel
 * and compiles a new one.
 */
public class TrainFlatNetworkOpenCL
implements TrainFlatNetwork {
    /** Learning type id for resilient propagation (RPROP). */
    public static final int LEARN_RPROP = 0;
    /** Learning type id for backpropagation. */
    public static final int LEARN_BPROP = 1;
    /** Learning type id for the Manhattan update rule. */
    public static final int LEARN_MANHATTAN = 2;

    /** Sentinel meaning that no {@code learnXXXX} method has been called yet. */
    private static final int LEARN_UNDEFINED = -1;

    /** Error of the most recent call to {@link #iteration(int)}. */
    private double error;
    private final FlatNetwork network;
    private final EngineIndexableSet training;
    /**
     * One of the {@code LEARN_*} constants, or {@link #LEARN_UNDEFINED}.
     * It is explicitly initialized because the int default (0) equals
     * {@link #LEARN_RPROP}, which would defeat the "not defined yet" check.
     */
    private int learningType = LEARN_UNDEFINED;
    private double learningRate;
    private double momentum;
    private double initialUpdate;
    private double maxStep;
    /** Compiled training kernel; {@code null} before learnXXXX or after finishTraining. */
    private KernelNetworkTrain kernel;
    private int iteration;
    private final OpenCLTrainingProfile profile;

    /**
     * Creates an OpenCL trainer.
     *
     * @param network  the network to train; validated by {@link ValidateForOpenCL}
     * @param training the training data; must be an {@link EngineIndexableSet}
     * @param profile  the OpenCL device / workload profile to use
     * @throws EncogEngineError if the training data is not indexable or OpenCL
     *                          has not been enabled
     */
    public TrainFlatNetworkOpenCL(FlatNetwork network, EngineDataSet training, OpenCLTrainingProfile profile) {
        new ValidateForOpenCL().validate(network);
        if (!(training instanceof EngineIndexableSet)) {
            throw new EncogEngineError("Training data must be Indexable for this training type.");
        }
        if (EncogEngine.getInstance().getCL() == null) {
            throw new EncogEngineError("You must enable OpenCL before using this training type.");
        }
        this.profile = profile;
        this.network = network;
        this.training = (EngineIndexableSet)training;
    }

    /**
     * Runs the kernel once and adds the sum of its per-work-item errors to
     * {@link #error}.
     */
    private void callKernel(int start, int size, boolean learn, int iterations) {
        this.kernel.calculate(start, size, learn, iterations);
        double e = 0.0;
        for (int i = 0; i < this.kernel.getGlobalWork(); ++i) {
            e += (double)this.kernel.getErrors()[i];
        }
        this.error += e;
    }

    /**
     * Throws if no usable kernel is available, either because no learning type
     * was selected or because training has been finished.
     */
    private void ensureKernel() {
        if (this.learningType == LEARN_UNDEFINED) {
            throw new EncogEngineError("Learning type has not been defined yet, you must first call one of the learnXXXX methods, such as learnRPROP.");
        }
        if (this.kernel == null) {
            throw new EncogEngineError("Training has been finished, you must call one of the learnXXXX methods again before continuing.");
        }
    }

    /** Releases the current kernel, if any, and forgets the reference so it is never released twice. */
    private void releaseKernel() {
        if (this.kernel != null) {
            this.kernel.release();
            this.kernel = null;
        }
    }

    /**
     * Releases any previous kernel, then creates and compiles a new one.
     *
     * @param learningTypeDefine the preprocessor define naming the algorithm
     * @param tempDataSize       size requested for the kernel's temp data array
     */
    private void createKernel(String learningTypeDefine, int tempDataSize) {
        this.releaseKernel();
        // Until compilation succeeds, no learning type is usable.
        this.learningType = LEARN_UNDEFINED;
        Map<String, String> options = this.getOptions(learningTypeDefine);
        // Assigned before compile() so that finishTraining() can still release it if compilation fails.
        this.kernel = new KernelNetworkTrain(this.profile.getDevice(), this.network, this.training, tempDataSize);
        this.kernel.compile(options, this.profile, this.network);
    }

    /** Releases the OpenCL kernel. Safe to call more than once. */
    @Override
    public void finishTraining() {
        this.releaseKernel();
    }

    /** @return the error computed by the most recent iteration */
    @Override
    public double getError() {
        return this.error;
    }

    /** @return the number of iterations performed so far */
    @Override
    public int getIteration() {
        return this.iteration;
    }

    /**
     * Copies the first {@code weights.length} entries of the kernel's temp
     * data array.
     *
     * <p>NOTE(review): under RPROP, {@link #learnRPROP(double, double)} zeroes
     * this region, so it presumably holds the last gradients. Under BPROP,
     * entries 0 and 1 hold the learning rate and momentum. Under Manhattan,
     * the array has size 1. Confirm the layout against the kernel source.
     *
     * @return a new array sized like the network's weights
     * @throws EncogEngineError if no kernel is available
     */
    public double[] getLastGradient() {
        this.ensureKernel();
        double[] result = new double[this.network.getWeights().length];
        for (int i = 0; i < result.length; ++i) {
            result[i] = this.kernel.getTempDataArray()[i];
        }
        return result;
    }

    /** @return the learning rate set by {@link #learnBPROP} or {@link #learnManhattan} */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @return one of the {@code LEARN_*} constants, or -1 if none has been selected */
    public int getLearningType() {
        return this.learningType;
    }

    /** @return the max step passed to {@link #learnRPROP(double, double)} */
    public double getMaxStep() {
        return this.maxStep;
    }

    /** @return the momentum passed to {@link #learnBPROP(double, double)} */
    public double getMomentum() {
        return this.momentum;
    }

    /** @return the network being trained */
    @Override
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /** Always 0: thread count does not apply to OpenCL training. */
    @Override
    public int getNumThreads() {
        return 0;
    }

    /**
     * Builds the kernel compile options: neuron and weight counts, plus a
     * value-less define naming the learning algorithm.
     */
    private Map<String, String> getOptions(String learningType) {
        Map<String, String> options = new HashMap<String, String>();
        options.put("NEURON_COUNT", String.valueOf(this.network.getNeuronCount()));
        options.put("WEIGHT_COUNT", String.valueOf(this.network.getWeights().length));
        // A null value marks a flag-style define with no value.
        options.put(learningType, null);
        return options;
    }

    /** @return the training set this trainer was constructed with */
    @Override
    public EngineDataSet getTraining() {
        return this.training;
    }

    /**
     * Copies the second {@code weights.length} entries of the kernel's temp
     * data array. {@link #learnRPROP(double, double)} seeds this region with
     * the initial update value.
     *
     * <p>NOTE(review): only the RPROP kernel allocates this region (2 x weight
     * count); under BPROP or Manhattan this will read out of range.
     *
     * @return a new array sized like the network's weights
     * @throws EncogEngineError if no kernel is available
     */
    public double[] getUpdateValues() {
        this.ensureKernel();
        int len = this.network.getWeights().length;
        double[] result = new double[len];
        for (int i = 0; i < len; ++i) {
            result[i] = this.kernel.getTempDataArray()[len + i];
        }
        return result;
    }

    /** Performs a single training iteration. */
    @Override
    public void iteration() {
        this.iteration(1);
    }

    /**
     * Performs {@code iterations} training iterations on the OpenCL device,
     * then copies the trained weights back into the network.
     *
     * @param iterations number of iterations to batch into the final kernel call
     * @throws EncogEngineError if no learning type has been selected, training
     *                          was finished, or the profile splits work across
     *                          several kernel calls while {@code iterations > 1}
     */
    @Override
    public void iteration(int iterations) {
        this.ensureKernel();
        int count = this.profile.getKernelNumberOfCalls();
        // A ratio below 1.0 splits one iteration across several calls; batching iterations is then impossible.
        // Validate before mutating any state so a rejected call leaves the trainer untouched.
        if (count > 0 && iterations > 1) {
            throw new EncogEngineError("Must use an OpenCL ratio of 1.0 if you are going to use an iteration count > 1.");
        }
        this.iteration += iterations;
        this.error = 0.0;
        int currentIndex = 0;
        // Full-size workloads: gradients only (learn == false), one iteration each.
        this.kernel.setGlobalWork(this.profile.getKernelGlobalWorkgroup());
        this.kernel.setLocalWork(this.profile.getKernelLocalWorkgroup());
        while (count > 0) {
            this.callKernel(currentIndex, this.profile.getKernelWorkPerCall(), false, 1);
            --count;
            currentIndex += this.profile.getKernelWorkPerCall() * this.kernel.getGlobalWork();
        }
        // Final (remainder) workload applies the weight update.
        // NOTE(review): local work is set to the remainder *global* size (one workgroup); confirm this is intended.
        this.kernel.setGlobalWork(this.profile.getKernelRemainderGlobal());
        this.kernel.setLocalWork(this.profile.getKernelRemainderGlobal());
        this.callKernel(currentIndex, this.profile.getKernelRemainderPer(), true, iterations);
        // Computed in double: an int product of record count and ideal size can overflow on large sets.
        double valueCount = (double)this.training.getRecordCount() * (double)this.training.getIdealSize();
        this.error /= valueCount;
        if (ErrorCalculation.getMode() == ErrorCalculationMode.RMS) {
            this.error = Math.sqrt(this.error);
        }
        EngineArray.arrayCopy(this.kernel.getWeightOutArray(), this.network.getWeights());
    }

    /**
     * Selects backpropagation and compiles the kernel.
     *
     * @param learningRate the learning rate, written to temp slot 0 as a float
     * @param momentum     the momentum, written to temp slot 1 as a float
     */
    public void learnBPROP(double learningRate, double momentum) {
        this.momentum = momentum;
        this.learningRate = learningRate;
        this.createKernel("LEARN_BPROP", this.network.getWeights().length + 2);
        this.kernel.getTempDataArray()[0] = (float)learningRate;
        this.kernel.getTempDataArray()[1] = (float)momentum;
        this.learningType = LEARN_BPROP;
    }

    /**
     * Selects the Manhattan update rule and compiles the kernel.
     *
     * @param learningRate the learning rate, written to temp slot 0 as a float
     */
    public void learnManhattan(double learningRate) {
        this.learningRate = learningRate;
        this.createKernel("LEARN_MANHATTAN", 1);
        this.kernel.getTempDataArray()[0] = (float)learningRate;
        this.learningType = LEARN_MANHATTAN;
    }

    /** Selects RPROP with an initial update of 0.1 and a max step of 50.0. */
    public void learnRPROP() {
        this.learnRPROP(0.1, 50.0);
    }

    /**
     * Selects RPROP and compiles the kernel. The temp data array is sized to
     * two times the weight count. The first half is zeroed and the second half
     * is filled with {@code initialUpdate}.
     *
     * <p>NOTE(review): {@code maxStep} is only stored here and is not written
     * to the kernel; confirm the kernel does not need it.
     *
     * @param initialUpdate the initial per-weight update value
     * @param maxStep       the maximum step, exposed via {@link #getMaxStep()}
     */
    public void learnRPROP(double initialUpdate, double maxStep) {
        this.initialUpdate = initialUpdate;
        this.maxStep = maxStep;
        int weightLength = this.network.getWeights().length;
        this.createKernel("LEARN_RPROP", weightLength * 2);
        for (int i = 0; i < weightLength; ++i) {
            this.kernel.getTempDataArray()[i] = 0.0f;
            this.kernel.getTempDataArray()[i + weightLength] = (float)this.initialUpdate;
        }
        this.learningType = LEARN_RPROP;
    }

    /** Overrides the iteration counter. */
    @Override
    public void setIteration(int iteration) {
        this.iteration = iteration;
    }

    /** Ignored: thread count does not apply to OpenCL training. */
    @Override
    public void setNumThreads(int numThreads) {
    }
}

