/*
 * Decompiled with CFR 0.152.
 */
package jcuda.vec;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.Map;
import jcuda.CudaException;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.CUresult;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.vec.VecKernels;

final class DefaultVecKernels
implements VecKernels {
    private final CUmodule module;
    private final String kernelNamePrefix;
    private final String kernelNameSuffix;
    private final Map<String, CUfunction> functions;
    private int blockDimX;
    private static final int deviceNumber = 0;
    private CUstream stream;

    DefaultVecKernels(String kernelNameType, String kernelNamePrefix, String kernelNameSuffix) {
        this.kernelNamePrefix = kernelNamePrefix;
        this.kernelNameSuffix = kernelNameSuffix;
        DefaultVecKernels.initCUDA();
        this.blockDimX = DefaultVecKernels.getMaxBlockDimX();
        this.module = new CUmodule();
        String modelString = System.getProperty("sun.arch.data.model");
        int ccMajor = DefaultVecKernels.getComputeCapabilityMajor();
        String ccString = "20";
        if (ccMajor > 2) {
            ccString = "30";
        }
        String ptxFileName = "/kernels/JCudaVec_kernels_" + kernelNameType + "_" + modelString + "_cc" + ccString + ".ptx";
        byte[] ptxData = DefaultVecKernels.loadData(ptxFileName);
        DefaultVecKernels.checkResult(JCudaDriver.cuModuleLoadDataEx((CUmodule)this.module, (Pointer)Pointer.to((byte[])ptxData), (int)0, (int[])new int[0], (Pointer)Pointer.to((int[])new int[0])));
        this.functions = new LinkedHashMap<String, CUfunction>();
    }

    private static int getMaxBlockDimX() {
        CUdevice device = new CUdevice();
        DefaultVecKernels.checkResult(JCudaDriver.cuDeviceGet((CUdevice)device, (int)0));
        int[] maxBlockDimX = new int[]{0};
        JCudaDriver.cuDeviceGetAttribute((int[])maxBlockDimX, (int)2, (CUdevice)device);
        return maxBlockDimX[0];
    }

    private static int getComputeCapabilityMajor() {
        CUdevice device = new CUdevice();
        DefaultVecKernels.checkResult(JCudaDriver.cuDeviceGet((CUdevice)device, (int)0));
        int[] ccMajor = new int[]{0};
        JCudaDriver.cuDeviceGetAttribute((int[])ccMajor, (int)75, (CUdevice)device);
        return ccMajor[0];
    }

    private static void initCUDA() {
        DefaultVecKernels.checkResult(JCudaDriver.cuInit((int)0));
        CUcontext context = new CUcontext();
        DefaultVecKernels.checkResult(JCudaDriver.cuCtxGetCurrent((CUcontext)context));
        CUcontext nullContext = new CUcontext();
        if (context.equals((Object)nullContext)) {
            DefaultVecKernels.createContext();
        }
    }

    private static void createContext() {
        CUdevice device = new CUdevice();
        DefaultVecKernels.checkResult(JCudaDriver.cuDeviceGet((CUdevice)device, (int)0));
        CUcontext context = new CUcontext();
        DefaultVecKernels.checkResult(JCudaDriver.cuCtxCreate((CUcontext)context, (int)0, (CUdevice)device));
    }

    private static void checkResult(int cuResult) {
        if (cuResult != 0) {
            throw new CudaException(CUresult.stringFor((int)cuResult));
        }
    }

    private static byte[] loadData(String ptxFileName) {
        InputStream ptxInputStream = null;
        try {
            ptxInputStream = DefaultVecKernels.class.getResourceAsStream(ptxFileName);
            if (ptxInputStream != null) {
                byte[] byArray = DefaultVecKernels.loadData(ptxInputStream);
                return byArray;
            }
            throw new CudaException("Could not initialize the kernels: Resource " + ptxFileName + " not found");
        }
        finally {
            if (ptxInputStream != null) {
                try {
                    ptxInputStream.close();
                }
                catch (IOException e) {
                    throw new CudaException("Could not initialize the kernels", (Throwable)e);
                }
            }
        }
    }

    private static byte[] loadData(InputStream inputStream) {
        ByteArrayOutputStream baos = null;
        try {
            int read;
            baos = new ByteArrayOutputStream();
            byte[] buffer = new byte[8192];
            while ((read = inputStream.read(buffer)) != -1) {
                baos.write(buffer, 0, read);
            }
            baos.write(0);
            baos.flush();
            byte[] byArray = baos.toByteArray();
            return byArray;
        }
        catch (IOException e) {
            throw new CudaException("Could not load data", (Throwable)e);
        }
        finally {
            if (baos != null) {
                try {
                    baos.close();
                }
                catch (IOException e) {
                    throw new CudaException("Could not close output", (Throwable)e);
                }
            }
        }
    }

    @Override
    public void call(String name, long workSize, Object ... arguments) {
        CUfunction function = this.obtainFunction(name);
        Pointer kernelParameters = this.setupKernelParameters(arguments);
        this.callKernel(workSize, function, kernelParameters);
    }

    private CUfunction obtainFunction(String name) {
        CUfunction function = this.functions.get(name);
        if (function == null) {
            function = new CUfunction();
            DefaultVecKernels.checkResult(JCudaDriver.cuModuleGetFunction((CUfunction)function, (CUmodule)this.module, (String)(this.kernelNamePrefix + name + this.kernelNameSuffix)));
        }
        return function;
    }

    private Pointer setupKernelParameters(Object ... args) {
        Pointer[] kernelParameters = new Pointer[args.length];
        for (int i = 0; i < args.length; ++i) {
            Number value;
            Pointer pointer;
            Object arg = args[i];
            if (arg == null) {
                throw new NullPointerException("Argument " + i + " is null");
            }
            if (arg instanceof Pointer) {
                Pointer argPointer = (Pointer)arg;
                kernelParameters[i] = pointer = Pointer.to((NativePointerObject[])new NativePointerObject[]{argPointer});
                continue;
            }
            if (arg instanceof Byte) {
                value = (Byte)arg;
                kernelParameters[i] = pointer = Pointer.to((byte[])new byte[]{(Byte)value});
                continue;
            }
            if (arg instanceof Short) {
                value = (Short)arg;
                kernelParameters[i] = pointer = Pointer.to((short[])new short[]{(Short)value});
                continue;
            }
            if (arg instanceof Integer) {
                value = (Integer)arg;
                kernelParameters[i] = pointer = Pointer.to((int[])new int[]{(Integer)value});
                continue;
            }
            if (arg instanceof Long) {
                value = (Long)arg;
                kernelParameters[i] = pointer = Pointer.to((long[])new long[]{(Long)value});
                continue;
            }
            if (arg instanceof Float) {
                value = (Float)arg;
                kernelParameters[i] = pointer = Pointer.to((float[])new float[]{((Float)value).floatValue()});
                continue;
            }
            if (arg instanceof Double) {
                value = (Double)arg;
                kernelParameters[i] = pointer = Pointer.to((double[])new double[]{(Double)value});
                continue;
            }
            throw new CudaException("Type " + arg.getClass() + " may not be passed to a function");
        }
        return Pointer.to((NativePointerObject[])kernelParameters);
    }

    private void callKernel(long workSize, CUfunction function, Pointer kernelParameters) {
        int gridDimX = (int)Math.ceil((double)workSize / (double)this.blockDimX);
        DefaultVecKernels.checkResult(JCudaDriver.cuLaunchKernel((CUfunction)function, (int)gridDimX, (int)1, (int)1, (int)this.blockDimX, (int)1, (int)1, (int)0, (CUstream)this.stream, (Pointer)kernelParameters, null));
    }

    @Override
    public void shutdown() {
        JCudaDriver.cuModuleUnload((CUmodule)this.module);
    }
}

