/*
 * Decompiled with CFR 0.152.
 */
package org.encog.engine.opencl.kernels;

import java.util.HashMap;
import org.encog.engine.data.BasicEngineData;
import org.encog.engine.data.EngineData;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.opencl.EncogCLDevice;
import org.encog.engine.opencl.EncogCLQueue;
import org.encog.engine.opencl.exceptions.OpenCLError;
import org.encog.engine.opencl.exceptions.OutOfOpenCLResources;
import org.encog.engine.opencl.kernels.EncogKernel;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ErrorCalculation;
import org.encog.engine.util.ResourceLoader;
import org.jocl.CLException;
import org.jocl.cl_mem;

/**
 * Wrapper around the {@code NetworkCalc} OpenCL kernel, whose source is loaded from
 * {@code org/encog/engine/resources/KernelNetCalc.txt}.
 *
 * <p>Expected call order, as implied by the code below:
 * <ol>
 * <li>{@link #setFlat(FlatNetwork)} uploads the network structure and compiles the kernel.</li>
 * <li>{@link #setTraining(EngineIndexableSet)} uploads the training data.</li>
 * <li>{@link #calculate(int, int)} runs the kernel.</li>
 * <li>{@link #getErrors()} / {@link #getError()} read back the results.</li>
 * </ol>
 * {@link #release()} frees the device buffers.
 *
 * <p>NOTE(review): this class holds mutable host arrays and device buffers without any
 * synchronization; confirm each instance is confined to a single thread.
 */
public class KernelNetworkCalc
extends EncogKernel {
    /** Index in the parameter array holding the network's input count. */
    public static final int PARRAY_INPUT_COUNT = 0;
    /** Index in the parameter array holding the network's output count. */
    public static final int PARRAY_OUTPUT_COUNT = 1;
    /** Index in the parameter array holding the number of layers. */
    public static final int PARRAY_LAYER_COUNT = 2;
    /** Index in the parameter array presumably reserved for a learning flag; not written here. */
    public static final int PARRAY_LEARN = 3;
    /** Index in the parameter array holding the {@code start} argument of {@link #calculate}. */
    public static final int PARRAY_START = 4;
    /** Index in the parameter array holding the {@code size} argument of {@link #calculate}. */
    public static final int PARRAY_ITEMS_PER = 5;
    /** Index in the parameter array presumably reserved for an iteration count; not written here. */
    public static final int PARRAY_ITERATIONS = 6;

    /** Local work-group size passed to the kernel on every run. */
    private static final int LOCAL_WORK_SIZE = 64;

    /** Device copy of {@link #weightInArray}. */
    private cl_mem weightInArrayBuffer;
    /** Device copy of {@code flat.getLayerIndex()}. */
    private cl_mem layerIndexBuffer;
    /** Device copy of {@code flat.getLayerCounts()}. */
    private cl_mem layerCountBuffer;
    /** Device copy of {@code flat.getLayerFeedCounts()}. */
    private cl_mem layerFeedCountBuffer;
    /** Device copy of {@code flat.getWeightIndex()}. */
    private cl_mem weightIndexBuffer;
    /** Host copy of the network weights, refreshed and uploaded before every run. */
    private float[] weightInArray;
    /** Flattened training inputs, one record after another. */
    private float[] inputArray;
    /** Flattened training ideals, one record after another. */
    private float[] idealArray;
    private cl_mem inputBuffer;
    private cl_mem layerOutputBuffer;
    private cl_mem idealBuffer;
    /** Host array that receives {@link #layerOutputBuffer} after each run. */
    private float[] layerOutput;
    /** Integer parameters passed to the kernel; indices are the {@code PARRAY_*} constants. */
    private int[] paramArray;
    private cl_mem paramBuffer;
    private cl_mem errorBuffer;
    private FlatNetwork flat;
    /** One error value per training record, downloaded from {@link #errorBuffer}. */
    private float[] errors;
    private EngineIndexableSet training;
    private final EncogCLDevice device;
    /** Number of records in {@link #training}, truncated to int. */
    private int trainingLength;

    /**
     * Creates the kernel wrapper for a device and allocates the parameter buffer.
     *
     * @param device the OpenCL device whose queue is used by {@link #calculate(int, int)}
     */
    public KernelNetworkCalc(EncogCLDevice device) {
        super(device, "org/encog/engine/resources/KernelNetCalc.txt", "NetworkCalc");
        this.device = device;
        this.paramArray = new int[10];
        this.paramBuffer = this.createArrayReadOnly(this.paramArray);
    }

    /**
     * Runs the kernel and downloads its results.
     *
     * <p>Uploads the current network weights and the parameter array, then executes the kernel
     * with a global work size of {@code size} and a local work size of 64. It waits for
     * completion, then copies the device results into {@link #getErrors()} and the internal
     * layer-output array.
     *
     * @param start value written to {@link #PARRAY_START} (presumably the first record index)
     * @param size  global work size, also written to {@link #PARRAY_ITEMS_PER}
     * @throws OutOfOpenCLResources if OpenCL reports {@code CL_OUT_OF_RESOURCES}
     * @throws OpenCLError for any other failure while transferring data or running the kernel
     */
    public void calculate(int start, int size) {
        this.prepareKernel();
        this.paramArray[PARRAY_START] = start;
        this.paramArray[PARRAY_ITEMS_PER] = size;
        this.setGlobalWork(size);
        // NOTE(review): OpenCL 1.x requires an explicit local size to divide the global size;
        // confirm callers only pass sizes that are multiples of LOCAL_WORK_SIZE.
        this.setLocalWork(LOCAL_WORK_SIZE);
        EngineArray.arrayCopy(this.flat.getWeights(), this.weightInArray);
        this.setArg(0, this.paramBuffer);
        this.setArg(1, this.errorBuffer);
        this.setArg(2, this.layerIndexBuffer);
        this.setArg(3, this.layerCountBuffer);
        this.setArg(4, this.layerFeedCountBuffer);
        this.setArg(5, this.weightIndexBuffer);
        this.setArg(6, this.inputBuffer);
        this.setArg(7, this.idealBuffer);
        this.setArg(8, this.weightInArrayBuffer);
        this.setArg(9, this.layerOutputBuffer);
        try {
            EncogCLQueue queue = this.device.getQueue();
            queue.array2Buffer(this.weightInArray, this.weightInArrayBuffer);
            queue.array2Buffer(this.paramArray, this.paramBuffer);
            queue.execute(this);
            queue.waitFinish();
            queue.buffer2Array(this.errorBuffer, this.errors);
            queue.buffer2Array(this.layerOutputBuffer, this.layerOutput);
        }
        catch (CLException e) {
            // getMessage() may be null; compare from the literal side so a missing message
            // does not turn into an NPE that hides the original OpenCL failure.
            if ("CL_OUT_OF_RESOURCES".equals(e.getMessage())) {
                throw new OutOfOpenCLResources(e);
            }
            throw new OpenCLError(e);
        }
        catch (Exception e) {
            throw new OpenCLError(e);
        }
    }

    /**
     * Builds the kernel source for a network and compiles it.
     *
     * <p>The source is an {@code ACTIVATION(x,slope)} macro definition followed by the kernel
     * resource. The compile options {@code NEURON_COUNT} and {@code WEIGHT_COUNT} are taken from
     * the network.
     *
     * @param network the network whose activation function and sizes are baked into the kernel
     */
    public void compile(FlatNetwork network) {
        // Only the first activation function is read, so all layers are compiled with it.
        ActivationFunction activation = network.getActivationFunctions()[0];
        StringBuilder source = new StringBuilder();
        source.append("#define ACTIVATION(x,slope)");
        source.append(activation.getOpenCLExpression(false));
        source.append("\r\n");
        source.append(ResourceLoader.loadString(this.getSourceName()));
        this.setCLSource(source.toString());
        HashMap<String, String> options = new HashMap<String, String>();
        options.put("NEURON_COUNT", String.valueOf(network.getNeuronCount()));
        options.put("WEIGHT_COUNT", String.valueOf(network.getWeights().length));
        this.compile(options);
    }

    /**
     * Returns the per-record error values from the last {@link #calculate(int, int)}.
     *
     * @return the internal error array (not a copy); null before {@link #setTraining} is called
     */
    public float[] getErrors() {
        return this.errors;
    }

    /**
     * Releases the kernel and every device buffer owned by this instance.
     */
    @Override
    public void release() {
        super.release();
        this.errorBuffer = this.releaseIfAllocated(this.errorBuffer);
        this.idealBuffer = this.releaseIfAllocated(this.idealBuffer);
        this.inputBuffer = this.releaseIfAllocated(this.inputBuffer);
        this.layerCountBuffer = this.releaseIfAllocated(this.layerCountBuffer);
        this.layerFeedCountBuffer = this.releaseIfAllocated(this.layerFeedCountBuffer);
        this.layerIndexBuffer = this.releaseIfAllocated(this.layerIndexBuffer);
        // Allocated in allocateCommon(); it must be freed here too or its device memory leaks.
        this.layerOutputBuffer = this.releaseIfAllocated(this.layerOutputBuffer);
        this.paramBuffer = this.releaseIfAllocated(this.paramBuffer);
        this.weightInArrayBuffer = this.releaseIfAllocated(this.weightInArrayBuffer);
        this.weightIndexBuffer = this.releaseIfAllocated(this.weightIndexBuffer);
    }

    /**
     * Releases a device buffer if it is non-null.
     *
     * @param buffer the buffer to release, may be null
     * @return always null, so callers can clear the field in the same statement
     */
    private cl_mem releaseIfAllocated(cl_mem buffer) {
        if (buffer != null) {
            this.releaseBuffer(buffer);
        }
        return null;
    }

    /**
     * Returns the network currently bound to this kernel.
     *
     * @return the network, or null if {@link #setFlat(FlatNetwork)} has not been called
     */
    public FlatNetwork getFlat() {
        return this.flat;
    }

    /**
     * Binds a network to this kernel.
     *
     * <p>This records the network sizes in the parameter array and re-uploads the network
     * structure arrays to fresh device buffers. It then (re)allocates the layer-output buffer
     * if training data is present, and recompiles the kernel for this network.
     *
     * <p>NOTE(review): training data already uploaded by {@link #setTraining} is not re-laid-out
     * for the new network; call {@code setTraining} again if the input/output counts changed.
     *
     * @param flat the network to bind; must not be null
     */
    public void setFlat(FlatNetwork flat) {
        this.flat = flat;
        this.weightInArray = new float[flat.getWeights().length];
        this.paramArray[PARRAY_INPUT_COUNT] = flat.getInputCount();
        this.paramArray[PARRAY_OUTPUT_COUNT] = flat.getOutputCount();
        this.paramArray[PARRAY_LAYER_COUNT] = flat.getLayerCounts().length;
        this.layerCountBuffer = this.releaseIfAllocated(this.layerCountBuffer);
        this.layerFeedCountBuffer = this.releaseIfAllocated(this.layerFeedCountBuffer);
        this.layerIndexBuffer = this.releaseIfAllocated(this.layerIndexBuffer);
        this.weightInArrayBuffer = this.releaseIfAllocated(this.weightInArrayBuffer);
        this.weightIndexBuffer = this.releaseIfAllocated(this.weightIndexBuffer);
        this.layerIndexBuffer = this.createArrayReadOnly(flat.getLayerIndex());
        this.layerCountBuffer = this.createArrayReadOnly(flat.getLayerCounts());
        this.layerFeedCountBuffer = this.createArrayReadOnly(flat.getLayerFeedCounts());
        this.weightInArrayBuffer = this.createArrayReadOnly(this.weightInArray);
        this.weightIndexBuffer = this.createArrayReadOnly(flat.getWeightIndex());
        this.allocateCommon();
        this.compile(flat);
    }

    /**
     * (Re)allocates the layer-output host array and device buffer.
     *
     * <p>Their size depends on both the network and the training-set length, so this does
     * nothing until both are set.
     */
    private void allocateCommon() {
        if (this.training != null && this.flat != null) {
            this.layerOutputBuffer = this.releaseIfAllocated(this.layerOutputBuffer);
            this.layerOutput = new float[this.flat.getLayerOutput().length * this.trainingLength];
            this.layerOutputBuffer = this.createFloatArrayWriteOnly(this.layerOutput.length);
        }
    }

    /**
     * Returns the training set currently bound to this kernel.
     *
     * @return the training set, or null if none has been set
     */
    public EngineIndexableSet getTraining() {
        return this.training;
    }

    /**
     * Binds a training set to this kernel and uploads its data.
     *
     * <p>Every record's input and ideal values are flattened into float arrays. Each record
     * uses exactly the network's input/output count as its stride. The arrays are uploaded
     * together with a write-only buffer that holds one error value per record.
     *
     * @param training the training data; records are read by index from 0 to the record count
     * @throws IllegalStateException if {@link #setFlat(FlatNetwork)} has not been called yet
     */
    public void setTraining(EngineIndexableSet training) {
        if (this.flat == null) {
            throw new IllegalStateException("setFlat must be called before setTraining");
        }
        this.training = training;
        this.trainingLength = (int)training.getRecordCount();
        int inputCount = this.flat.getInputCount();
        int outputCount = this.flat.getOutputCount();
        EngineData pair = BasicEngineData.createPair(inputCount, outputCount);
        // Size the arrays by the same per-record counts the fill loop writes with. These are
        // also the strides handed to the kernel through PARRAY_INPUT_COUNT/PARRAY_OUTPUT_COUNT.
        // Sizing by the training set's own input/ideal sizes could overflow, or leave unused
        // padding, whenever the two disagree.
        this.inputArray = new float[inputCount * this.trainingLength];
        this.idealArray = new float[outputCount * this.trainingLength];
        int inputIndex = 0;
        int idealIndex = 0;
        for (int i = 0; i < this.trainingLength; ++i) {
            training.getRecord(i, pair);
            for (int col = 0; col < inputCount; ++col) {
                this.inputArray[inputIndex++] = (float)pair.getInputArray()[col];
            }
            for (int col = 0; col < outputCount; ++col) {
                this.idealArray[idealIndex++] = (float)pair.getIdealArray()[col];
            }
        }
        this.errors = new float[this.trainingLength];
        this.errorBuffer = this.releaseIfAllocated(this.errorBuffer);
        this.idealBuffer = this.releaseIfAllocated(this.idealBuffer);
        this.inputBuffer = this.releaseIfAllocated(this.inputBuffer);
        this.errorBuffer = this.createFloatArrayWriteOnly(this.errors.length);
        this.inputBuffer = this.createArrayReadOnly(this.inputArray);
        this.idealBuffer = this.createArrayReadOnly(this.idealArray);
        this.allocateCommon();
    }

    /**
     * Computes the overall error from the last {@link #calculate(int, int)}.
     *
     * <p>The result is the sum of the per-record errors divided by the number of records times
     * the network's output count. If the training set is empty, the result is NaN (0/0).
     *
     * @return the averaged error
     */
    public double getError() {
        double sum = 0.0;
        for (float recordError : this.errors) {
            sum += recordError;
        }
        // Compute the denominator in double so a large record count cannot overflow int.
        return sum / ((double)this.errors.length * this.flat.getOutputCount());
    }
}

