/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.gpu.context;

import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicLong;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcusparse.JCusparse;
import jcuda.jcusparse.cusparseHandle;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysml.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysml.runtime.instructions.gpu.context.JCudaKernels;
import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysml.utils.GPUStatistics;

/**
 * {@link GPUContext} implementation backed by JCuda.
 *
 * <p>Loading the class initialises the CUDA driver and logs basic device information.
 * Constructing an instance registers it as {@code GPUContext.currContext}, records the
 * free device memory and creates the cuDNN, cuBLAS and cuSPARSE handles that are stored
 * in {@link LibMatrixCUDA}.
 *
 * <p>NOTE(review): several blocks synchronize on the boxed {@code isGPUContextCreated}
 * field, which is reassigned in {@link #destroy()}. That makes it a fragile lock object,
 * but it is kept because the parent class is not visible here and may lock on the same
 * field - confirm before changing.
 */
public class JCudaContext extends GPUContext {
    /** Shared lock object exposed for callers; not used inside this class. */
    public static final Object syncObj = new Object();
    private static final Log LOG = LogFactory.getLog(JCudaContext.class.getName());

    /** How long the constructor waits for a previously created context to be closed. */
    private static final long CONTEXT_CLOSE_TIMEOUT_MS = 60000L;
    /** Sleep interval between checks while waiting for a previous context to be closed. */
    private static final long CONTEXT_CLOSE_POLL_INTERVAL_MS = 100L;

    public static ArrayList<GPUObject> allocatedPointers = new ArrayList<GPUObject>();

    /** Minimum major compute capability accepted by {@link #ensureComputeCapability()}. */
    final int MAJOR_REQUIRED = 3;
    /** Minimum minor compute capability accepted when the major version equals {@link #MAJOR_REQUIRED}. */
    final int MINOR_REQUIRED = 0;

    /** Number of devices reported by the driver; {@code -1} until the static initializer has run. */
    public static int deviceCount = -1;
    public static boolean DEBUG = false;

    /** Free device memory in bytes as last recorded from {@code cudaMemGetInfo}. */
    AtomicLong deviceMemBytes = new AtomicLong(0L);

    /** Per-device property cache, indexed by device number and filled lazily. */
    private static cudaDeviceProp[] deviceProperties;

    /** Multiplier applied to the recorded free memory; read from {@code gpu.memory.util.factor}. */
    public double GPU_MEMORY_UTILIZATION_FACTOR = ConfigurationManager.getDMLConfig().getDoubleValue("gpu.memory.util.factor");
    /** When true, free memory is re-queried on every {@link #getAvailableMemory()} call; read from {@code gpu.memory.refresh}. */
    public boolean REFRESH_AVAILABLE_MEMORY_EVERY_TIME = ConfigurationManager.getDMLConfig().getBooleanValue("gpu.memory.refresh");

    /**
     * Returns the recorded free device memory scaled by {@link #GPU_MEMORY_UTILIZATION_FACTOR}.
     *
     * <p>If {@link #REFRESH_AVAILABLE_MEMORY_EVERY_TIME} is set, the free memory is first
     * re-read from the device.
     *
     * @return the usable memory in bytes
     * @throws RuntimeException if the memory information cannot be obtained
     */
    @Override
    public long getAvailableMemory() {
        if (this.REFRESH_AVAILABLE_MEMORY_EVERY_TIME) {
            long[] free = new long[]{0L};
            long[] total = new long[]{0L};
            if (JCuda.cudaMemGetInfo(free, total) != 0) {
                throw new RuntimeException("ERROR: Unable to get memory information of the GPU.");
            }
            this.deviceMemBytes.set(free[0]);
        }
        return (long) (this.deviceMemBytes.get() * this.GPU_MEMORY_UTILIZATION_FACTOR);
    }

    /**
     * Verifies that every CUDA device has at least compute capability
     * {@link #MAJOR_REQUIRED}.{@link #MINOR_REQUIRED}.
     *
     * @throws DMLRuntimeException if no device is reported or any device is below the
     *         required compute capability
     */
    @Override
    public void ensureComputeCapability() throws DMLRuntimeException {
        int[] devices = new int[]{-1};
        JCuda.cudaGetDeviceCount(devices);
        // Reject both the untouched sentinel (-1) and a genuine count of zero.
        if (devices[0] <= 0) {
            throw new DMLRuntimeException("Call to cudaGetDeviceCount returned 0 devices");
        }
        boolean isComputeCapable = true;
        for (int i = 0; i < devices[0]; ++i) {
            cudaDeviceProp properties = JCudaContext.getGPUProperties(i);
            int major = properties.major;
            int minor = properties.minor;
            if (major < MAJOR_REQUIRED || (major == MAJOR_REQUIRED && minor < MINOR_REQUIRED)) {
                isComputeCapable = false;
            }
        }
        if (!isComputeCapable) {
            throw new DMLRuntimeException("One of the CUDA cards on the system has compute capability lower than "
                    + MAJOR_REQUIRED + "." + MINOR_REQUIRED);
        }
    }

    /**
     * Returns the properties of the device reported by {@code cudaGetDevice}.
     *
     * @return the cached properties of that device
     */
    public static cudaDeviceProp getGPUProperties() {
        int[] device = new int[]{-1};
        JCuda.cudaGetDevice(device);
        return JCudaContext.getGPUProperties(device[0]);
    }

    /**
     * Returns the properties of the given device, querying them on first use and caching
     * the result.
     *
     * @param device device number; must be a valid index into the property cache, which is
     *        sized by the device count found at class initialisation
     * @return the properties of the device
     */
    public static cudaDeviceProp getGPUProperties(int device) {
        if (deviceProperties[device] == null) {
            cudaDeviceProp properties = new cudaDeviceProp();
            JCuda.cudaGetDeviceProperties(properties, device);
            deviceProperties[device] = properties;
        }
        return deviceProperties[device];
    }

    /**
     * @return {@code maxThreadsPerBlock} of the device reported by {@code cudaGetDevice}
     */
    public static int getMaxThreadsPerBlock() {
        return JCudaContext.getGPUProperties().maxThreadsPerBlock;
    }

    /**
     * @return the first component of {@code maxGridSize} of the device reported by {@code cudaGetDevice}
     */
    public static int getMaxBlocks() {
        return JCudaContext.getGPUProperties().maxGridSize[0];
    }

    /**
     * @return {@code sharedMemPerBlock} of the device reported by {@code cudaGetDevice}
     */
    public static long getMaxSharedMemory() {
        return JCudaContext.getGPUProperties().sharedMemPerBlock;
    }

    /**
     * @return {@code warpSize} of the device reported by {@code cudaGetDevice}
     */
    public static int getWarpSize() {
        return JCudaContext.getGPUProperties().warpSize;
    }

    /**
     * Atomically adds {@code v} to the recorded free device memory.
     *
     * @param v number of bytes to add; may be negative
     * @return the recorded value before the addition
     */
    public long getAndAddAvailableMemory(long v) {
        return this.deviceMemBytes.getAndAdd(v);
    }

    /**
     * Creates the context, registers it as {@code GPUContext.currContext}, records the
     * device memory and creates the CUDA library handles.
     *
     * <p>If a context is already marked as created, the constructor first waits up to
     * {@value #CONTEXT_CLOSE_TIMEOUT_MS} ms for it to be closed.
     *
     * @throws DMLRuntimeException declared for callers; kernel initialisation failures are
     *         logged and leave {@code LibMatrixCUDA.kernels} null instead of propagating
     * @throws RuntimeException if a previous context is still registered after the wait, or
     *         if the memory information cannot be obtained
     */
    public JCudaContext() throws DMLRuntimeException {
        if (isGPUContextCreated.booleanValue()) {
            waitForPreviousContextToClose();
        }
        synchronized (isGPUContextCreated) {
            GPUContext.currContext = this;
        }

        long[] free = new long[]{0L};
        long[] total = new long[]{0L};
        if (JCuda.cudaMemGetInfo(free, total) != 0) {
            throw new RuntimeException("ERROR: Unable to get memory information of the GPU.");
        }
        long totalNumBytes = total[0];
        this.deviceMemBytes.set(free[0]);
        LOG.info("Total GPU memory: " + (totalNumBytes * 1.0E-6) + " MB");
        LOG.info("Available GPU memory: " + (this.deviceMemBytes.get() * 1.0E-6) + " MB");

        initializeCudaLibraries();
    }

    /**
     * Polls until the previous context is no longer marked as created or the timeout
     * elapses, then fails if a context is still registered.
     */
    private void waitForPreviousContextToClose() {
        final long waitStart = System.currentTimeMillis();
        boolean interrupted = false;
        do {
            try {
                Thread.sleep(CONTEXT_CLOSE_POLL_INTERVAL_MS);
            } catch (InterruptedException e) {
                // Keep waiting as before, but remember the interrupt so it is not lost.
                interrupted = true;
            }
        } while (isGPUContextCreated.booleanValue()
                && System.currentTimeMillis() - waitStart < CONTEXT_CLOSE_TIMEOUT_MS);
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
        synchronized (isGPUContextCreated) {
            if (GPUContext.currContext != null) {
                throw new RuntimeException("Cannot create multiple JCudaContext. Waited for 60 seconds to close previous GPUContext");
            }
        }
    }

    /**
     * Creates the cuDNN, cuBLAS and cuSPARSE handles and the kernel wrapper, storing them in
     * {@link LibMatrixCUDA}, and records the elapsed time in {@link GPUStatistics}.
     */
    private void initializeCudaLibraries() {
        final long start = System.nanoTime();
        LibMatrixCUDA.cudnnHandle = new cudnnHandle();
        JCudnn.cudnnCreate(LibMatrixCUDA.cudnnHandle);
        LibMatrixCUDA.cublasHandle = new cublasHandle();
        JCublas2.cublasCreate(LibMatrixCUDA.cublasHandle);
        LibMatrixCUDA.cusparseHandle = new cusparseHandle();
        JCusparse.cusparseCreate(LibMatrixCUDA.cusparseHandle);
        try {
            LibMatrixCUDA.kernels = new JCudaKernels();
        } catch (DMLRuntimeException e) {
            // Best effort: construction continues without kernels, but the cause is kept in the log.
            LOG.error("ERROR - Unable to initialize JCudaKernels. System in an inconsistent state", e);
            LibMatrixCUDA.kernels = null;
        }
        GPUStatistics.cudaLibrariesInitTime = System.nanoTime() - start;
    }

    /**
     * Destroys the CUDA library handles and unregisters the current context.
     *
     * <p>NOTE(review): the handles in {@link LibMatrixCUDA} are not set to null afterwards,
     * so a second call takes the error branch below - confirm whether that is intended.
     *
     * @throws DMLRuntimeException if no context is registered but a cuDNN or cuBLAS handle
     *         still exists
     */
    @Override
    public void destroy() throws DMLRuntimeException {
        if (currContext != null) {
            synchronized (isGPUContextCreated) {
                JCudnn.cudnnDestroy(LibMatrixCUDA.cudnnHandle);
                JCublas2.cublasDestroy(LibMatrixCUDA.cublasHandle);
                JCusparse.cusparseDestroy(LibMatrixCUDA.cusparseHandle);
                currContext = null;
                isGPUContextCreated = false;
            }
        } else if (LibMatrixCUDA.cudnnHandle != null || LibMatrixCUDA.cublasHandle != null) {
            throw new DMLRuntimeException("Error while destroying the GPUContext");
        }
    }

    // One-time CUDA setup: enable exceptions in all JCuda bindings, initialise the driver,
    // size the device property cache and log basic device information.
    static {
        long start = System.nanoTime();
        JCuda.setExceptionsEnabled(true);
        JCudnn.setExceptionsEnabled(true);
        JCublas2.setExceptionsEnabled(true);
        JCusparse.setExceptionsEnabled(true);
        JCudaDriver.setExceptionsEnabled(true);
        JCudaDriver.cuInit(0);
        int[] deviceCountArray = new int[]{0};
        JCudaDriver.cuDeviceGetCount(deviceCountArray);
        deviceCount = deviceCountArray[0];
        deviceProperties = new cudaDeviceProp[deviceCount];
        LOG.info("Total number of GPUs on the machine: " + deviceCount);
        int maxBlocks = JCudaContext.getMaxBlocks();
        int maxThreadsPerBlock = JCudaContext.getMaxThreadsPerBlock();
        long sharedMemPerBlock = JCudaContext.getMaxSharedMemory();
        int[] device = new int[]{-1};
        JCuda.cudaGetDevice(device);
        LOG.info("Active CUDA device number : " + device[0]);
        LOG.info("Max Blocks/Threads/SharedMem : " + maxBlocks + "/" + maxThreadsPerBlock + "/" + sharedMemPerBlock);
        GPUStatistics.cudaInitTime = System.nanoTime() - start;
    }
}

