/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.gpu.context;

import jcuda.Pointer;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcusolver.JCusolverDn;
import jcuda.jcusolver.cusolverDnHandle;
import jcuda.jcusparse.JCusparse;
import jcuda.jcusparse.cusparseHandle;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysml.runtime.instructions.gpu.context.GPUMemoryManager;
import org.apache.sysml.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysml.runtime.instructions.gpu.context.JCudaKernels;
import org.apache.sysml.utils.GPUStatistics;

public class GPUContext {
    protected static final Log LOG = LogFactory.getLog(GPUContext.class.getName());
    final int MAJOR_REQUIRED = 3;
    final int MINOR_REQUIRED = 0;
    private final int deviceNum;
    private cudnnHandle cudnnHandle;
    private cublasHandle cublasHandle;
    private cusparseHandle cusparseHandle;
    private cusolverDnHandle cusolverDnHandle;
    private JCudaKernels kernels;
    private GPUMemoryManager memoryManager;

    public GPUMemoryManager getMemoryManager() {
        return this.memoryManager;
    }

    protected GPUContext(int deviceNum) throws DMLRuntimeException {
        this.deviceNum = deviceNum;
        JCuda.cudaSetDevice((int)deviceNum);
        JCuda.cudaSetDeviceFlags((int)4);
        long start = -1L;
        if (DMLScript.STATISTICS) {
            start = System.nanoTime();
        }
        this.initializeCudaLibraryHandles();
        if (DMLScript.STATISTICS) {
            GPUStatistics.cudaLibrariesInitTime = System.nanoTime() - start;
        }
        this.memoryManager = new GPUMemoryManager(this);
    }

    public static int cudaGetDevice() {
        int[] device = new int[1];
        JCuda.cudaGetDevice((int[])device);
        return device[0];
    }

    public void printMemoryInfo(String opcode) throws DMLRuntimeException {
        if (LOG.isDebugEnabled()) {
            LOG.debug(opcode + ": " + this.memoryManager.toString());
        }
    }

    private void initializeCudaLibraryHandles() throws DMLRuntimeException {
        this.deleteCudaLibraryHandles();
        if (this.cudnnHandle == null) {
            this.cudnnHandle = new cudnnHandle();
            JCudnn.cudnnCreate((cudnnHandle)this.cudnnHandle);
        }
        if (this.cublasHandle == null) {
            this.cublasHandle = new cublasHandle();
            JCublas2.cublasCreate((cublasHandle)this.cublasHandle);
        }
        if (this.cusparseHandle == null) {
            this.cusparseHandle = new cusparseHandle();
            JCusparse.cusparseCreate((cusparseHandle)this.cusparseHandle);
        }
        if (this.cusolverDnHandle == null) {
            this.cusolverDnHandle = new cusolverDnHandle();
            JCusolverDn.cusolverDnCreate((cusolverDnHandle)this.cusolverDnHandle);
        }
        if (this.kernels == null) {
            this.kernels = new JCudaKernels();
        }
    }

    public int getDeviceNum() {
        return this.deviceNum;
    }

    public void initializeThread() throws DMLRuntimeException {
        JCuda.cudaSetDevice((int)this.deviceNum);
        this.initializeCudaLibraryHandles();
    }

    public Pointer allocate(long size) throws DMLRuntimeException {
        return this.memoryManager.malloc(null, size);
    }

    public Pointer allocate(String instructionName, long size) throws DMLRuntimeException {
        return this.memoryManager.malloc(instructionName, size);
    }

    public void cudaFreeHelper(Pointer toFree) throws DMLRuntimeException {
        this.cudaFreeHelper(null, toFree, DMLScript.EAGER_CUDA_FREE);
    }

    public void cudaFreeHelper(Pointer toFree, boolean eager) throws DMLRuntimeException {
        this.cudaFreeHelper(null, toFree, eager);
    }

    public void cudaFreeHelper(String instructionName, Pointer toFree) throws DMLRuntimeException {
        this.cudaFreeHelper(instructionName, toFree, DMLScript.EAGER_CUDA_FREE);
    }

    public void cudaFreeHelper(String instructionName, Pointer toFree, boolean eager) throws DMLRuntimeException {
        this.memoryManager.free(instructionName, toFree, eager);
    }

    public long getAvailableMemory() {
        return this.memoryManager.getAvailableMemory();
    }

    public void ensureComputeCapability() throws DMLRuntimeException {
        int[] devices = new int[]{-1};
        JCuda.cudaGetDeviceCount((int[])devices);
        if (devices[0] == -1) {
            throw new DMLRuntimeException("Call to cudaGetDeviceCount returned 0 devices");
        }
        boolean isComputeCapable = true;
        for (int i = 0; i < devices[0]; ++i) {
            cudaDeviceProp properties = GPUContextPool.getGPUProperties(i);
            int major = properties.major;
            int minor = properties.minor;
            if (major < 3) {
                isComputeCapable = false;
                continue;
            }
            if (major != 3 || minor >= 0) continue;
            isComputeCapable = false;
        }
        if (!isComputeCapable) {
            throw new DMLRuntimeException("One of the CUDA cards on the system has compute capability lower than 3.0");
        }
    }

    public GPUObject createGPUObject(MatrixObject mo) {
        GPUObject ret = new GPUObject(this, mo);
        this.getMemoryManager().addGPUObject(ret);
        return ret;
    }

    public cudaDeviceProp getGPUProperties() throws DMLRuntimeException {
        return GPUContextPool.getGPUProperties(this.deviceNum);
    }

    public int getMaxThreadsPerBlock() throws DMLRuntimeException {
        cudaDeviceProp deviceProps = this.getGPUProperties();
        return deviceProps.maxThreadsPerBlock;
    }

    public int getMaxBlocks() throws DMLRuntimeException {
        cudaDeviceProp deviceProp = this.getGPUProperties();
        return deviceProp.maxGridSize[0];
    }

    public long getMaxSharedMemory() throws DMLRuntimeException {
        cudaDeviceProp deviceProp = this.getGPUProperties();
        return deviceProp.sharedMemPerBlock;
    }

    public int getWarpSize() throws DMLRuntimeException {
        cudaDeviceProp deviceProp = this.getGPUProperties();
        return deviceProp.warpSize;
    }

    public cudnnHandle getCudnnHandle() {
        return this.cudnnHandle;
    }

    public cublasHandle getCublasHandle() {
        return this.cublasHandle;
    }

    public cusparseHandle getCusparseHandle() {
        return this.cusparseHandle;
    }

    public cusolverDnHandle getCusolverDnHandle() {
        return this.cusolverDnHandle;
    }

    public JCudaKernels getKernels() {
        return this.kernels;
    }

    public void destroy() throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : this context was destroyed, this = " + this.toString());
        }
        this.clearMemory();
        this.deleteCudaLibraryHandles();
    }

    private void deleteCudaLibraryHandles() {
        if (this.cudnnHandle != null) {
            JCudnn.cudnnDestroy((cudnnHandle)this.cudnnHandle);
        }
        if (this.cublasHandle != null) {
            JCublas2.cublasDestroy((cublasHandle)this.cublasHandle);
        }
        if (this.cusparseHandle != null) {
            JCusparse.cusparseDestroy((cusparseHandle)this.cusparseHandle);
        }
        if (this.cusolverDnHandle != null) {
            JCusolverDn.cusolverDnDestroy((cusolverDnHandle)this.cusolverDnHandle);
        }
        this.cudnnHandle = null;
        this.cublasHandle = null;
        this.cusparseHandle = null;
        this.cusolverDnHandle = null;
    }

    public void clearMemory() throws DMLRuntimeException {
        this.memoryManager.clearMemory();
    }

    public void clearTemporaryMemory() throws DMLRuntimeException {
        this.memoryManager.clearTemporaryMemory();
    }

    public String toString() {
        return "GPUContext{deviceNum=" + this.deviceNum + '}';
    }
}

