/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.opencl.kernels;

import java.util.Map;
import org.encog.engine.data.BasicEngineData;
import org.encog.engine.data.EngineData;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.prop.OpenCLTrainingProfile;
import org.encog.engine.opencl.EncogCLDevice;
import org.encog.engine.opencl.EncogCLQueue;
import org.encog.engine.opencl.exceptions.OpenCLError;
import org.encog.engine.opencl.exceptions.OutOfOpenCLResources;
import org.encog.engine.opencl.kernels.EncogKernel;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ResourceLoader;
import org.jocl.CLException;
import org.jocl.cl_mem;

/**
 * OpenCL kernel wrapper for the {@code NetworkTrain} kernel, which trains a
 * {@link FlatNetwork} on an OpenCL device.
 *
 * <p>The constructor flattens the whole training set into host-side
 * {@code float} arrays. {@link #compile(Map, OpenCLTrainingProfile, FlatNetwork)}
 * builds and compiles the kernel source and allocates the device buffers.
 * {@link #calculate(int, int, boolean, int)} then runs one pass and reads the
 * results back into host arrays.
 */
public class KernelNetworkTrain extends EncogKernel {

    /** Slot in the parameter array holding the network input count. */
    public static final int PARRAY_INPUT_COUNT = 0;

    /** Slot in the parameter array holding the network output count. */
    public static final int PARRAY_OUTPUT_COUNT = 1;

    /** Slot in the parameter array holding the number of layers. */
    public static final int PARRAY_LAYER_COUNT = 2;

    /** Slot in the parameter array holding the learn flag (1 = learn, 0 = not). */
    public static final int PARRAY_LEARN = 3;

    /** Slot in the parameter array holding the start index passed to {@code calculate}. */
    public static final int PARRAY_START = 4;

    /** Slot in the parameter array holding the {@code size} passed to {@code calculate}. */
    public static final int PARRAY_ITEMS_PER = 5;

    /** Slot in the parameter array holding the iteration count passed to {@code calculate}. */
    public static final int PARRAY_ITERATIONS = 6;

    private cl_mem weightInArrayBuffer;
    private cl_mem weightOutArrayBuffer;
    private cl_mem layerIndexBuffer;
    private cl_mem layerCountBuffer;
    private cl_mem layerFeedCountBuffer;
    private cl_mem weightIndexBuffer;
    private cl_mem activationTypeBuffer;
    private cl_mem tempDataInBuffer;
    private cl_mem tempDataOutBuffer;
    private final float[] weightInArray;
    private final float[] weightOutArray;
    private float[] tempDataArray;
    private int layerDeltaSize;
    private final float[] inputArray;
    private final float[] idealArray;
    private cl_mem inputBuffer;
    private cl_mem idealBuffer;
    // Sized 10, although only slots PARRAY_INPUT_COUNT..PARRAY_ITERATIONS (0..6) are written here.
    private final int[] paramArray;
    private cl_mem paramBuffer;
    private cl_mem errorBuffer;
    private cl_mem gradientOutBuffer;
    private cl_mem gradientInBuffer;
    private final FlatNetwork flat;
    private float[] errors;
    private final float[] gradients;
    private final EngineIndexableSet training;
    private final EncogCLDevice device;
    private final int trainingLength;

    /**
     * Creates the kernel and copies the training data into flat host arrays.
     *
     * <p>Input values are stored record after record, {@code getInputCount()} floats
     * per record. Ideal values are stored the same way, {@code getOutputCount()}
     * floats per record. Device buffers are not allocated here; see
     * {@link #init(OpenCLTrainingProfile)}.
     *
     * @param device       the OpenCL device to run on
     * @param flat         the network being trained
     * @param training     the training set. Its record count is narrowed to {@code int}.
     * @param tempDataSize length of the temp-data array shared with the kernel
     */
    public KernelNetworkTrain(EncogCLDevice device, FlatNetwork flat, EngineIndexableSet training, int tempDataSize) {
        super(device, "org/encog/engine/resources/KernelNetTrain.txt", "NetworkTrain");
        this.training = training;
        this.trainingLength = (int) this.training.getRecordCount();
        this.device = device;
        this.flat = flat;

        final int weightCount = flat.getWeights().length;
        this.weightInArray = new float[weightCount];
        this.weightOutArray = new float[weightCount];
        this.tempDataArray = new float[tempDataSize];
        this.gradients = new float[weightCount];

        // Total neuron count across all layers.
        this.layerDeltaSize = 0;
        for (int i = 0; i < flat.getLayerCounts().length; ++i) {
            this.layerDeltaSize += flat.getLayerCounts()[i];
        }

        final int inputSize = flat.getInputCount();
        final int idealSize = flat.getOutputCount();
        this.inputArray = new float[inputSize * this.trainingLength];
        this.idealArray = new float[idealSize * this.trainingLength];
        this.paramArray = new int[10];

        // One reusable pair object; getRecord fills it for each record.
        EngineData pair = BasicEngineData.createPair(inputSize, idealSize);
        int inputIndex = 0;
        int idealIndex = 0;
        for (int i = 0; i < this.trainingLength; ++i) {
            training.getRecord(i, pair);
            for (int col = 0; col < inputSize; ++col) {
                this.inputArray[inputIndex++] = (float) pair.getInputArray()[col];
            }
            for (int col = 0; col < idealSize; ++col) {
                this.idealArray[idealIndex++] = (float) pair.getIdealArray()[col];
            }
        }
    }

    /**
     * Sets the kernel's global and local work sizes.
     *
     * <p>The global size is {@code min(trainingSize, requestedGlobalSize)}. The
     * local size is that value capped at the device's maximum work-group size.
     *
     * <p>NOTE(review): OpenCL 1.x requires the global work size to be a multiple
     * of the local work size. When {@code threads > getMaxWorkGroupSize()} and it
     * is not a multiple, the enqueue may be rejected. Confirm against the devices
     * this code targets.
     *
     * @param trainingSize        number of training records
     * @param requestedGlobalSize requested upper bound on the global work size
     */
    public void assignWorkgroupSizes(int trainingSize, int requestedGlobalSize) {
        int threads = Math.min(trainingSize, requestedGlobalSize);
        this.setLocalWork(Math.min(this.getMaxWorkGroupSize(), threads));
        this.setGlobalWork(threads);
    }

    /**
     * Runs one execution of the training kernel and reads the results back.
     *
     * <p>Before the run it copies the network's current weights into the input
     * weight buffer, zeroes the gradient input, and uploads the temp data and the
     * parameter array. After the run finishes it reads back the errors, output
     * weights, temp data and gradients into {@link #getErrors()},
     * {@link #getWeightOutArray()}, {@link #getTempDataArray()} and the internal
     * gradient array.
     *
     * <p>NOTE(review): the gradient output buffer holds
     * {@code globalWorkgroup * weights.length} floats, but it is read back into an
     * array of {@code weights.length} floats. Whether only a prefix is copied
     * depends on {@code EncogCLQueue.buffer2Array}. Confirm.
     *
     * @param start      value stored in {@link #PARRAY_START}
     * @param size       value stored in {@link #PARRAY_ITEMS_PER}
     * @param learn      stored in {@link #PARRAY_LEARN} as 1 (true) or 0 (false)
     * @param iterations value stored in {@link #PARRAY_ITERATIONS}
     * @throws OutOfOpenCLResources if the OpenCL call fails with {@code CL_OUT_OF_RESOURCES}
     * @throws OpenCLError          for any other failure while enqueuing, running or reading back
     */
    public void calculate(int start, int size, boolean learn, int iterations) {
        this.prepareKernel();

        this.paramArray[PARRAY_LEARN] = learn ? 1 : 0;
        this.paramArray[PARRAY_START] = start;
        this.paramArray[PARRAY_ITEMS_PER] = size;
        this.paramArray[PARRAY_ITERATIONS] = iterations;

        EngineArray.arrayCopy(this.flat.getWeights(), this.weightInArray);

        // Argument positions presumably mirror the NetworkTrain kernel's parameter
        // list in KernelNetTrain.txt. Keep both in sync.
        this.setArg(0, this.paramBuffer);
        this.setArg(1, this.errorBuffer);
        this.setArg(2, this.layerIndexBuffer);
        this.setArg(3, this.layerCountBuffer);
        this.setArg(4, this.layerFeedCountBuffer);
        this.setArg(5, this.weightIndexBuffer);
        this.setArg(6, this.inputBuffer);
        this.setArg(7, this.idealBuffer);
        this.setArg(8, this.weightInArrayBuffer);
        this.setArg(9, this.weightOutArrayBuffer);
        this.setArg(10, this.gradientOutBuffer);
        this.setArg(11, this.activationTypeBuffer);
        this.setArg(12, this.tempDataInBuffer);
        this.setArg(13, this.tempDataOutBuffer);
        this.setArg(14, this.gradientInBuffer);

        try {
            EncogCLQueue queue = this.device.getQueue();
            EngineArray.fill(this.gradients, 0.0f);

            queue.array2Buffer(this.weightInArray, this.weightInArrayBuffer);
            queue.array2Buffer(this.tempDataArray, this.tempDataInBuffer);
            queue.array2Buffer(this.gradients, this.gradientInBuffer);
            queue.array2Buffer(this.paramArray, this.paramBuffer);

            queue.execute(this);
            queue.waitFinish();

            queue.buffer2Array(this.errorBuffer, this.errors);
            queue.buffer2Array(this.weightOutArrayBuffer, this.weightOutArray);
            queue.buffer2Array(this.tempDataOutBuffer, this.tempDataArray);
            queue.buffer2Array(this.gradientOutBuffer, this.gradients);
        } catch (CLException e) {
            // Compare against the literal so a CLException with a null message
            // does not throw an NPE that would hide the original failure.
            if ("CL_OUT_OF_RESOURCES".equals(e.getMessage())) {
                throw new OutOfOpenCLResources(e);
            }
            throw new OpenCLError(e);
        } catch (Exception e) {
            throw new OpenCLError(e);
        }
    }

    /**
     * Builds the kernel source, compiles it, sizes the work groups and allocates
     * the device buffers.
     *
     * <p>The source consists of {@code ACTIVATION(x,slope)} and
     * {@code DERIVATIVE(x,slope)} macros followed by the kernel resource file.
     * The macros are taken from the network's first activation function only
     * ({@code getActivationFunctions()[0]}), so every layer uses that function on
     * the device.
     *
     * @param options compiler options passed to {@code compile(Map)}
     * @param profile training profile used to compute kernel parameters and buffer sizes
     * @param network network whose first activation function supplies the macros
     */
    public void compile(Map<String, String> options, OpenCLTrainingProfile profile, FlatNetwork network) {
        ActivationFunction activation = network.getActivationFunctions()[0];
        StringBuilder source = new StringBuilder();
        source.append("#define ACTIVATION(x,slope)");
        source.append(activation.getOpenCLExpression(false));
        source.append("\r\n");
        source.append("#define DERIVATIVE(x,slope)");
        source.append(activation.getOpenCLExpression(true));
        source.append("\r\n");
        source.append(ResourceLoader.loadString(this.getSourceName()));
        this.setCLSource(source.toString());
        this.compile(options);
        profile.calculateKernelParams(this, this.training);
        this.init(profile);
    }

    /**
     * Returns the per-work-item errors read back by the last {@code calculate}.
     *
     * @return the error array. It is {@code null} until {@link #init} has run.
     */
    public float[] getErrors() {
        return this.errors;
    }

    /**
     * Returns the host-side temp-data array that is uploaded to and read back from the kernel.
     *
     * @return the live temp-data array (not a copy)
     */
    public float[] getTempDataArray() {
        return this.tempDataArray;
    }

    /**
     * Returns the weights read back from the device by the last {@code calculate}.
     *
     * @return the live output-weight array (not a copy)
     */
    public float[] getWeightOutArray() {
        return this.weightOutArray;
    }

    /**
     * Allocates every device buffer used by the kernel and fills the fixed
     * parameter slots (input count, output count, layer count).
     *
     * <p>The error buffer has {@code profile.getKernelGlobalWorkgroup()} floats.
     * The gradient output buffer has that many floats times the weight count.
     *
     * @param profile training profile supplying the global work-group size
     */
    public void init(OpenCLTrainingProfile profile) {
        int errorSize = profile.getKernelGlobalWorkgroup();
        int gradientSize = profile.getKernelGlobalWorkgroup() * this.flat.getWeights().length;
        this.errors = new float[errorSize];

        this.paramArray[PARRAY_INPUT_COUNT] = this.flat.getInputCount();
        this.paramArray[PARRAY_OUTPUT_COUNT] = this.flat.getOutputCount();
        this.paramArray[PARRAY_LAYER_COUNT] = this.flat.getLayerCounts().length;

        this.inputBuffer = this.createArrayReadOnly(this.inputArray);
        this.idealBuffer = this.createArrayReadOnly(this.idealArray);
        this.errorBuffer = this.createFloatArrayWriteOnly(errorSize);
        this.gradientOutBuffer = this.createFloatArrayWriteOnly(gradientSize);
        this.gradientInBuffer = this.createArrayReadOnly(this.gradients);
        this.paramBuffer = this.createArrayReadOnly(this.paramArray);
        this.layerIndexBuffer = this.createArrayReadOnly(this.flat.getLayerIndex());
        this.layerCountBuffer = this.createArrayReadOnly(this.flat.getLayerCounts());
        this.layerFeedCountBuffer = this.createArrayReadOnly(this.flat.getLayerFeedCounts());
        this.weightInArrayBuffer = this.createArrayReadOnly(this.weightInArray);
        this.weightOutArrayBuffer = this.createFloatArrayWriteOnly(this.weightInArray.length);
        this.weightIndexBuffer = this.createArrayReadOnly(this.flat.getWeightIndex());
        // NOTE(review): this uploads the layer counts, not per-layer activation
        // types. The activation function is compiled in via macros in compile(),
        // so the kernel may ignore argument 11. Confirm before relying on it.
        this.activationTypeBuffer = this.createArrayReadOnly(this.flat.getLayerCounts());
        this.tempDataInBuffer = this.createArrayReadOnly(this.tempDataArray);
        this.tempDataOutBuffer = this.createFloatArrayWriteOnly(this.tempDataArray.length);
    }

    /**
     * Releases the kernel (via {@code super.release()}) and then every device buffer
     * allocated in {@link #init}.
     */
    @Override
    public void release() {
        super.release();
        this.releaseBuffer(this.activationTypeBuffer);
        this.releaseBuffer(this.errorBuffer);
        this.releaseBuffer(this.gradientOutBuffer);
        this.releaseBuffer(this.gradientInBuffer);
        this.releaseBuffer(this.idealBuffer);
        this.releaseBuffer(this.inputBuffer);
        this.releaseBuffer(this.layerCountBuffer);
        this.releaseBuffer(this.layerFeedCountBuffer);
        this.releaseBuffer(this.layerIndexBuffer);
        this.releaseBuffer(this.paramBuffer);
        this.releaseBuffer(this.tempDataInBuffer);
        this.releaseBuffer(this.tempDataOutBuffer);
        this.releaseBuffer(this.weightInArrayBuffer);
        this.releaseBuffer(this.weightIndexBuffer);
        this.releaseBuffer(this.weightOutArrayBuffer);
    }

    /**
     * Replaces the host-side temp-data array.
     *
     * <p>NOTE(review): the device temp-data buffers are sized in {@link #init} from
     * the array present at that time. A replacement with a different length is
     * presumably unsafe for later {@code calculate} calls. Confirm.
     *
     * @param tempDataArray the new temp-data array
     */
    public void setTempDataArray(float[] tempDataArray) {
        this.tempDataArray = tempDataArray;
    }
}

