/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.gpu.context;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.CUresult;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
import org.apache.sysml.runtime.io.IOUtilFunctions;

/**
 * Loads the SystemML PTX kernels through the CUDA driver API and launches them by name.
 *
 * <p>The PTX image is read from the classpath resource {@code /kernels/SystemML.ptx} into a single
 * {@link CUmodule}. {@link CUfunction} handles resolved from that module are cached by kernel name
 * until {@link #shutdown()} unloads the module.
 *
 * <p>NOTE(review): the handle cache is a plain {@link HashMap}; this class does no synchronization
 * of its own, so concurrent use from several threads must be coordinated by the caller.
 */
public class JCudaKernels {
    /** Classpath location of the PTX image containing the SystemML kernels. */
    private static final String PTX_FILE_NAME = "/kernels/SystemML.ptx";

    /** Kernel handles already resolved from {@link #module}, keyed by kernel name. */
    private final HashMap<String, CUfunction> kernels = new HashMap<>();

    /** Module holding the loaded PTX image; {@code null} once {@link #shutdown()} has run. */
    private CUmodule module;

    /**
     * Initializes the CUDA driver, makes sure a context is current on the calling thread (creating
     * one on device 0 if none is), and loads the PTX module.
     *
     * @throws DMLRuntimeException if a driver call fails or the PTX resource is missing/unreadable
     */
    public JCudaKernels() throws DMLRuntimeException {
        initCUDA();
        Pointer ptxImage = initKernels(PTX_FILE_NAME);
        this.module = new CUmodule();
        // No JIT options are passed: zero options with empty option/value arrays.
        checkResult(JCudaDriver.cuModuleLoadDataEx(this.module, ptxImage, 0, new int[0], Pointer.to(new int[0])));
    }

    /**
     * Calls {@code cuInit(0)} and, if no CUDA context is current on this thread, creates one.
     *
     * @throws DMLRuntimeException if a driver call returns a non-success status
     */
    private static void initCUDA() throws DMLRuntimeException {
        checkResult(JCudaDriver.cuInit(0));
        CUcontext current = new CUcontext();
        checkResult(JCudaDriver.cuCtxGetCurrent(current));
        // A freshly constructed CUcontext wraps a null handle; if the queried context equals it,
        // no context is current. (Relies on JCuda's NativePointerObject equality comparing the
        // wrapped native pointer.)
        if (current.equals(new CUcontext())) {
            createContext();
        }
    }

    /**
     * Creates a CUDA context with default flags on device 0.
     *
     * <p>The handle is not stored; the code relies on {@code cuCtxCreate} making the new context
     * current for the calling thread. NOTE(review): the context is never destroyed by this class.
     *
     * @throws DMLRuntimeException if a driver call returns a non-success status
     */
    private static void createContext() throws DMLRuntimeException {
        int deviceNumber = 0;
        CUdevice device = new CUdevice();
        checkResult(JCudaDriver.cuDeviceGet(device, deviceNumber));
        CUcontext context = new CUcontext();
        checkResult(JCudaDriver.cuCtxCreate(context, 0, device));
    }

    /**
     * Unloads the PTX module (best effort: the unload status is not checked) and drops all cached
     * kernel handles, which become invalid with their module. Safe to call more than once.
     */
    public void shutdown() {
        if (this.module != null) {
            JCudaDriver.cuModuleUnload(this.module);
            this.module = null;
        }
        this.kernels.clear();
    }

    /**
     * Launches the named kernel and then blocks on {@code cudaDeviceSynchronize()}.
     *
     * @param name      kernel (entry function) name inside the PTX module
     * @param config    grid/block dimensions, dynamic shared memory size and stream
     * @param arguments kernel arguments; each must be a non-null {@link Pointer}, {@link Integer},
     *                  {@link Double}, {@link Long} or {@link Float}
     * @throws DMLRuntimeException if the kernel cannot be resolved, an argument is null or of an
     *                             unsupported type, or the launch returns a non-success status
     */
    public void launchKernel(String name, ExecutionConfig config, Object... arguments) throws DMLRuntimeException {
        CUfunction function = getFunction(name);
        Pointer[] kernelParams = new Pointer[arguments.length];
        for (int i = 0; i < arguments.length; ++i) {
            kernelParams[i] = toKernelParameter(arguments[i]);
        }
        checkResult(JCudaDriver.cuLaunchKernel(function,
                config.gridDimX, config.gridDimY, config.gridDimZ,
                config.blockDimX, config.blockDimY, config.blockDimZ,
                config.sharedMemBytes, config.stream,
                Pointer.to(kernelParams), null));
        // NOTE(review): the runtime-API status returned here is ignored, so errors reported only at
        // synchronization time are not surfaced to the caller.
        JCuda.cudaDeviceSynchronize();
    }

    /**
     * Returns the cached handle for {@code name}, resolving it from the module and caching it on
     * first use.
     */
    private CUfunction getFunction(String name) throws DMLRuntimeException {
        CUfunction function = this.kernels.get(name);
        if (function == null) {
            if (this.module == null) {
                throw new DMLRuntimeException("Cannot launch kernel " + name + ": the CUDA module has been unloaded by shutdown().");
            }
            function = new CUfunction();
            checkResult(JCudaDriver.cuModuleGetFunction(function, this.module, name));
            this.kernels.put(name, function);
        }
        return function;
    }

    /** Wraps one kernel argument into the host pointer form expected by {@code cuLaunchKernel}. */
    private static Pointer toKernelParameter(Object argument) throws DMLRuntimeException {
        if (argument == null) {
            throw new DMLRuntimeException("The argument to the kernel cannot be null.");
        }
        if (argument instanceof Pointer) {
            // Device pointers are passed by address: a pointer to the pointer value.
            return Pointer.to((Pointer) argument);
        }
        if (argument instanceof Integer) {
            return Pointer.to(new int[] {(Integer) argument});
        }
        if (argument instanceof Double) {
            return Pointer.to(new double[] {(Double) argument});
        }
        if (argument instanceof Long) {
            return Pointer.to(new long[] {(Long) argument});
        }
        if (argument instanceof Float) {
            return Pointer.to(new float[] {(Float) argument});
        }
        throw new DMLRuntimeException("The argument of type " + argument.getClass() + " is not supported.");
    }

    /**
     * Throws if a CUDA driver call returned anything other than {@code 0} ({@code CUDA_SUCCESS}).
     *
     * @param cuResult status code returned by a {@link JCudaDriver} call
     * @throws DMLRuntimeException carrying the symbolic {@link CUresult} name of the error
     */
    public static void checkResult(int cuResult) throws DMLRuntimeException {
        if (cuResult != 0) {
            throw new DMLRuntimeException(CUresult.stringFor(cuResult));
        }
    }

    /**
     * Reads the PTX classpath resource fully into memory and appends a NUL terminator, since the
     * driver treats a PTX image as a NUL-terminated text string.
     *
     * @param ptxFileName classpath resource path of the PTX file
     * @return host pointer to the NUL-terminated PTX bytes
     * @throws DMLRuntimeException if the resource does not exist or cannot be read
     */
    private Pointer initKernels(String ptxFileName) throws DMLRuntimeException {
        InputStream in = null;
        ByteArrayOutputStream out = null;
        try {
            in = JCudaKernels.class.getResourceAsStream(ptxFileName);
            if (in == null) {
                throw new DMLRuntimeException("The input file " + ptxFileName + " not found. (Hint: Please compile SystemML using -DenableGPU=true flag. Example: mvn package -DenableGPU=true).");
            }
            out = new ByteArrayOutputStream();
            byte[] buffer = new byte[8192];
            int read;
            while ((read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
            out.write(0);
            return Pointer.to(out.toByteArray());
        } catch (IOException e) {
            throw new DMLRuntimeException("Could not initialize the kernels", e);
        } finally {
            // Runs on success and on every failure path, so neither stream leaks.
            IOUtilFunctions.closeSilently(out);
            IOUtilFunctions.closeSilently(in);
        }
    }
}

