/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.gpu.context;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.LongAdder;
import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.gpu.context.CSRPointer;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysml.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysml.utils.GPUStatistics;

/**
 * Tracks device memory allocated through a GPU context and serves allocation requests.
 *
 * <p>{@link #malloc} tries, in order:
 * <ol>
 * <li>reusing a lazily released pointer of exactly the requested size;</li>
 * <li>a fresh {@code cudaMalloc};</li>
 * <li>freeing the smallest larger lazily released pointer and allocating;</li>
 * <li>freeing all lazily released pointers and allocating;</li>
 * <li>evicting unlocked {@link GPUObject}s (copying dirty ones back to the host) and allocating.</li>
 * </ol>
 *
 * <p>NOTE(review): this class performs no synchronization. Presumably an instance is confined to a
 * single GPU context / thread; confirm with callers.
 */
public class GPUMemoryManager {
    protected static final Log LOG = LogFactory.getLog(GPUMemoryManager.class.getName());

    /** If free device memory at construction is below this fraction of the total, a warning is logged. */
    private static final double WARN_UTILIZATION_FACTOR = 0.7;

    /**
     * Fraction of the currently free device memory that {@link #getAvailableMemory()} reports as
     * usable. It is read from the {@code sysml.gpu.memory.util.factor} configuration property.
     */
    public double GPU_MEMORY_UTILIZATION_FACTOR = ConfigurationManager.getDMLConfig().getDoubleValue("sysml.gpu.memory.util.factor");

    /**
     * Lazily released pointers, keyed by allocation size in bytes. These come from
     * {@link #free} with {@code eager == false}; they are still allocated on the device.
     */
    private final HashMap<Long, Set<Pointer>> rmvarGPUPointers = new HashMap<>();

    /** GPU objects registered via {@link #addGPUObject}; these are the eviction candidates in {@link #malloc}. */
    private final ArrayList<GPUObject> allocatedGPUObjects = new ArrayList<>();

    /** Every pointer currently allocated through this manager, mapped to its size in bytes. */
    private final HashMap<Pointer, Long> allocatedGPUPointers = new HashMap<>();

    /**
     * Creates a manager and logs the total and currently free device memory, as reported by
     * {@code cudaMemGetInfo}.
     *
     * <p>A warning is logged when the free memory is below {@link #WARN_UTILIZATION_FACTOR} of the
     * total. A warning is also logged when the initial GPU memory budget exceeds the driver's local
     * memory budget.
     *
     * @param gpuCtx the owning context; used only in log messages here
     */
    public GPUMemoryManager(GPUContext gpuCtx) {
        long[] free = new long[]{0L};
        long[] total = new long[]{0L};
        JCuda.cudaMemGetInfo(free, total);
        if ((double) free[0] < WARN_UTILIZATION_FACTOR * (double) total[0]) {
            LOG.warn("Potential under-utilization: GPU memory - Total: " + (double) total[0] * 1.0E-6 + " MB, Available: " + (double) free[0] * 1.0E-6 + " MB on " + gpuCtx + ". This can happen if there are other processes running on the GPU at the same time.");
        } else {
            LOG.info("GPU memory - Total: " + (double) total[0] * 1.0E-6 + " MB, Available: " + (double) free[0] * 1.0E-6 + " MB on " + gpuCtx);
        }
        if ((double) GPUContextPool.initialGPUMemBudget() > OptimizerUtils.getLocalMemBudget()) {
            LOG.warn("Potential under-utilization: GPU memory (" + GPUContextPool.initialGPUMemBudget() + ") > driver memory budget (" + OptimizerUtils.getLocalMemBudget() + "). Consider increasing the driver memory budget.");
        }
    }

    /** Registers a GPU object so that it becomes an eviction candidate. */
    public void addGPUObject(GPUObject gpuObj) {
        this.allocatedGPUObjects.add(gpuObj);
    }

    /** Unregisters every registered GPU object that {@code equals} the given one. */
    public void removeGPUObject(GPUObject gpuObj) {
        if (LOG.isDebugEnabled()) {
            LOG.debug("Removing the GPU object: " + gpuObj);
        }
        this.allocatedGPUObjects.removeIf(a -> a.equals(gpuObj));
    }

    /**
     * Returns the size in bytes recorded for a pointer allocated through this manager.
     *
     * @param ptr the pointer to look up
     * @return the recorded size, or {@code -1} if the pointer is not tracked
     */
    public long getSizeAllocatedGPUPointer(Pointer ptr) {
        Long size = this.allocatedGPUPointers.get(ptr);
        return size == null ? -1L : size;
    }

    /**
     * Calls {@code cudaMalloc} and records the new pointer on success.
     *
     * <p>A {@link CudaException} is logged as a warning and reported as {@code null}, so the caller
     * can fall back to the next allocation strategy.
     *
     * @return {@code A} on success, {@code null} if {@code cudaMalloc} threw
     */
    private Pointer cudaMallocWarnIfFails(Pointer A, long size) {
        try {
            JCuda.cudaMalloc(A, size);
            this.allocatedGPUPointers.put(A, size);
            return A;
        } catch (CudaException e) {
            LOG.warn("cudaMalloc failed immediately after cudaMemGetInfo reported that memory of size " + size + " is available. This usually happens if there are external programs trying to grab on to memory in parallel.");
            return null;
        }
    }

    /**
     * Attempts a fresh allocation and, at trace level, logs which outcome occurred.
     *
     * @param failureMessage trace prefix used when allocation fails; {@code size} is appended
     * @param successMessage trace prefix used when allocation succeeds; {@code size} is appended
     * @return the new pointer, or {@code null} on failure
     */
    private Pointer cudaMallocAndTrace(long size, String failureMessage, String successMessage) {
        Pointer A = this.cudaMallocWarnIfFails(new Pointer(), size);
        if (LOG.isTraceEnabled()) {
            LOG.trace((A == null ? failureMessage : successMessage) + size);
        }
        return A;
    }

    /**
     * Allocates {@code size} bytes of device memory and zeroes it with {@code cudaMemset}.
     *
     * <p>See the class documentation for the order of fallback strategies.
     *
     * @param opcode instruction opcode used for fine-grained statistics; may be {@code null}
     * @param size number of bytes to allocate
     * @return a pointer to {@code size} zeroed bytes
     * @throws DMLRuntimeException if {@code size} is negative or no strategy could satisfy the request
     */
    public Pointer malloc(String opcode, long size) throws DMLRuntimeException {
        if (size < 0L) {
            throw new DMLRuntimeException("Cannot allocate memory of size " + size);
        }
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        // 1. Reuse a lazily released pointer of exactly the requested size.
        Pointer A = this.getRmvarPointer(opcode, size);
        // 2. Allocate fresh memory if enough appears to be available.
        if (A == null && size <= this.getAvailableMemory()) {
            A = this.cudaMallocAndTrace(size, "Couldnot allocate a new pointer in the GPU memory:", "Allocated a new pointer in the GPU memory:");
        }
        // 3. Release the smallest lazily released pointer that is larger than the request, then retry.
        if (A == null) {
            long key = this.smallestRmvarSizeAbove(size);
            if (key != Long.MAX_VALUE) {
                this.guardedCudaFree(this.getRmvarPointer(opcode, key));
                A = this.cudaMallocAndTrace(size, "Couldnot reuse non-exact match of rmvarGPUPointers:", "Reuses a non-exact match from rmvarGPUPointers:");
            }
        }
        // 4. Release every lazily released pointer, then retry.
        if (A == null) {
            this.freeAllRmvarPointers();
            if (size <= this.getAvailableMemory()) {
                A = this.cudaMallocAndTrace(size, "Couldnot allocate a new pointer in the GPU memory after eager free:", "Allocated a new pointer in the GPU memory after eager free:");
            }
        }
        this.addMiscTime(opcode, GPUStatistics.cudaAllocTime, GPUStatistics.cudaAllocCount, "a", t0);
        // 5. Evict unlocked GPU objects, then retry.
        if (A == null) {
            t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            this.evictUntilAvailable(opcode, size);
            // Timer first, counter second, as addMiscTime's signature expects.
            this.addMiscTime(opcode, GPUStatistics.cudaEvictTime, GPUStatistics.cudaEvictionCount, "evict", t0);
            if (size <= this.getAvailableMemory()) {
                A = this.cudaMallocAndTrace(size, "Couldnot allocate a new pointer in the GPU memory after eviction:", "Allocated a new pointer in the GPU memory after eviction:");
            }
        }
        if (A == null) {
            throw new DMLRuntimeException("There is not enough memory on device for this matrix, request (" + size + "). " + this.toString());
        }
        t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        JCuda.cudaMemset(A, 0, size);
        this.addMiscTime(opcode, GPUStatistics.cudaMemSet0Time, GPUStatistics.cudaMemSet0Count, "az", t0);
        return A;
    }

    /**
     * Returns the smallest lazily released allocation size strictly greater than {@code size}.
     *
     * @return that size, or {@link Long#MAX_VALUE} if there is none
     */
    private long smallestRmvarSizeAbove(long size) {
        long best = Long.MAX_VALUE;
        for (Long k : this.rmvarGPUPointers.keySet()) {
            if (k > size) {
                best = Math.min(best, k);
            }
        }
        return best;
    }

    /** Calls {@code cudaFree} on every lazily released pointer. */
    private void freeAllRmvarPointers() {
        // Take a snapshot first: guardedCudaFree mutates rmvarGPUPointers.
        HashSet<Pointer> toFree = new HashSet<>();
        for (Set<Pointer> ptrs : this.rmvarGPUPointers.values()) {
            toFree.addAll(ptrs);
        }
        for (Pointer ptr : toFree) {
            this.guardedCudaFree(ptr);
        }
    }

    /**
     * Evicts unlocked GPU objects until the request fits or an eviction candidate is locked.
     *
     * <p>Objects are sorted by {@link GPUComparator} and taken from the end of the list. Dirty
     * objects are copied back to the host before their data is cleared.
     *
     * <p>NOTE(review): termination relies on {@code clearData} removing the object from
     * {@code allocatedGPUObjects} (presumably via {@link #removeGPUObject}), or on the freed memory
     * satisfying the request. Confirm this in {@code GPUObject}.
     */
    private void evictUntilAvailable(String opcode, long size) throws DMLRuntimeException {
        Collections.sort(this.allocatedGPUObjects, new GPUComparator(size));
        while (size > this.getAvailableMemory() && !this.allocatedGPUObjects.isEmpty()) {
            GPUObject toBeRemoved = this.allocatedGPUObjects.get(this.allocatedGPUObjects.size() - 1);
            if (toBeRemoved.isLocked()) {
                break;
            }
            if (toBeRemoved.dirty) {
                toBeRemoved.copyFromDeviceToHost(opcode, true);
            }
            toBeRemoved.clearData(true);
        }
    }

    /**
     * Tests whether a pointer equals a freshly constructed {@code new Pointer()}, i.e. a placeholder.
     *
     * <p>This relies on JCuda's value-based {@code Pointer.equals}. With identity equality the
     * method always returns false, which matches the previous (reference-comparison) behaviour.
     *
     * @return false for {@code null}
     */
    private static boolean isEmptyPointer(Pointer ptr) {
        return new Pointer().equals(ptr);
    }

    /**
     * Calls {@code cudaFree} on a tracked pointer and drops it from all bookkeeping.
     *
     * <p>An untracked placeholder pointer is ignored.
     *
     * @throws RuntimeException if the pointer is neither tracked nor a placeholder
     */
    private void guardedCudaFree(Pointer toFree) {
        Long size = this.allocatedGPUPointers.remove(toFree);
        if (size == null) {
            if (isEmptyPointer(toFree)) {
                return;
            }
            throw new RuntimeException("Attempting to free an unaccounted pointer:" + toFree);
        }
        Set<Pointer> freeList = this.rmvarGPUPointers.get(size);
        if (freeList != null && freeList.contains(toFree)) {
            this.remove(this.rmvarGPUPointers, size, toFree);
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Free-ing up the pointer: " + toFree);
        }
        JCuda.cudaFree(toFree);
    }

    /**
     * Releases a pointer previously returned by {@link #malloc}.
     *
     * @param opcode instruction opcode used for fine-grained statistics; may be {@code null}
     * @param toFree the pointer to release; an untracked placeholder pointer is ignored
     * @param eager if true, {@code cudaFree} is called immediately. Otherwise the pointer stays
     *     allocated and is recorded for reuse by a later {@link #malloc} of the same size.
     * @throws RuntimeException if the pointer is untracked (lazy path) or already lazily released
     */
    public void free(String opcode, Pointer toFree, boolean eager) throws DMLRuntimeException {
        if (isEmptyPointer(toFree) && !this.allocatedGPUPointers.containsKey(toFree)) {
            return;
        }
        if (eager) {
            long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            this.guardedCudaFree(toFree);
            this.addMiscTime(opcode, GPUStatistics.cudaDeAllocTime, GPUStatistics.cudaDeAllocCount, "f", t0);
            return;
        }
        Long size = this.allocatedGPUPointers.get(toFree);
        if (size == null) {
            throw new RuntimeException("ERROR : Internal state corrupted, cache block size map is not aware of a block it trying to free up");
        }
        Set<Pointer> freeList = this.rmvarGPUPointers.computeIfAbsent(size, k -> new HashSet<>());
        if (!freeList.add(toFree)) {
            throw new RuntimeException("GPU : Internal state corrupted, double free");
        }
    }

    /**
     * Clears every registered GPU object, then frees every pointer still tracked by this manager.
     *
     * <p>Dirty objects are first read back to the host via {@code acquireHostRead}.
     */
    public void clearMemory() throws DMLRuntimeException {
        // Iterate over a copy: clearData may modify allocatedGPUObjects (e.g. via removeGPUObject).
        // Modifying the list while iterating it directly would throw ConcurrentModificationException
        // or skip elements.
        for (GPUObject gpuObj : new ArrayList<>(this.allocatedGPUObjects)) {
            if (gpuObj.isDirty()) {
                LOG.debug("Attempted to free GPU Memory when a block[" + gpuObj + "] is still on GPU memory, copying it back to host.");
                gpuObj.acquireHostRead(null);
            }
            gpuObj.clearData(true);
        }
        this.allocatedGPUObjects.clear();
        for (Pointer toFree : new HashSet<>(this.allocatedGPUPointers.keySet())) {
            this.guardedCudaFree(toFree);
        }
    }

    /**
     * Collects the device pointers held by dirty registered GPU objects.
     *
     * <p>For sparse objects, the non-null {@code rowPtr}, {@code colInd} and {@code val} pointers are
     * collected. For dense objects, the dense pointer is collected.
     *
     * @throws RuntimeException if a dirty object has no sparse/dense pointer
     */
    private HashSet<Pointer> getDirtyPointers() {
        HashSet<Pointer> nonTemporaryPointers = new HashSet<>();
        for (GPUObject o : this.allocatedGPUObjects) {
            if (!o.isDirty()) {
                continue;
            }
            if (o.isSparse()) {
                CSRPointer sparse = o.getSparseMatrixCudaPointer();
                if (sparse == null) {
                    throw new RuntimeException("CSRPointer is null in clearTemporaryMemory");
                }
                if (sparse.rowPtr != null) {
                    nonTemporaryPointers.add(sparse.rowPtr);
                }
                if (sparse.colInd != null) {
                    nonTemporaryPointers.add(sparse.colInd);
                }
                if (sparse.val != null) {
                    nonTemporaryPointers.add(sparse.val);
                }
            } else {
                Pointer dense = o.getJcudaDenseMatrixPtr();
                if (dense == null) {
                    throw new RuntimeException("Pointer is null in clearTemporaryMemory");
                }
                nonTemporaryPointers.add(dense);
            }
        }
        return nonTemporaryPointers;
    }

    /** Returns a new set with the elements of {@code superset} that are not in {@code subset}. */
    private Set<Pointer> nonIn(Set<Pointer> superset, Set<Pointer> subset) {
        HashSet<Pointer> ret = new HashSet<>();
        for (Pointer superPtr : superset) {
            if (!subset.contains(superPtr)) {
                ret.add(superPtr);
            }
        }
        return ret;
    }

    /**
     * Frees every tracked pointer that is not held by a dirty registered GPU object.
     *
     * <p>NOTE(review): pointers of non-dirty registered GPU objects are also freed here, while those
     * objects remain registered. Confirm that callers expect this.
     */
    public void clearTemporaryMemory() {
        Set<Pointer> temporaryPointers = this.nonIn(this.allocatedGPUPointers.keySet(), this.getDirtyPointers());
        for (Pointer tmpPtr : temporaryPointers) {
            this.guardedCudaFree(tmpPtr);
        }
    }

    /**
     * Records elapsed time and one event in global statistics when statistics are enabled.
     *
     * <p>When fine-grained statistics are enabled and {@code opcode} is non-null, the elapsed time is
     * also recorded as instruction-level misc time.
     */
    private void addMiscTime(String opcode, LongAdder globalGPUTimer, LongAdder globalGPUCounter, String instructionLevelTimer, long startTime) {
        if (DMLScript.STATISTICS) {
            long totalTime = System.nanoTime() - startTime;
            globalGPUTimer.add(totalTime);
            globalGPUCounter.add(1L);
            if (opcode != null && DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(opcode, instructionLevelTimer, totalTime);
            }
        }
    }

    /** Records instruction-level misc time only; this happens when fine-grained statistics are enabled and {@code opcode} is non-null. */
    private void addMiscTime(String opcode, String instructionLevelTimer, long startTime) {
        if (opcode != null && DMLScript.FINEGRAINED_STATISTICS) {
            GPUStatistics.maintainCPMiscTimes(opcode, instructionLevelTimer, System.nanoTime() - startTime);
        }
    }

    /**
     * Takes one lazily released pointer of exactly {@code size} bytes out of the free lists.
     *
     * @return that pointer (still tracked as allocated), or {@code null} if none exists
     */
    private Pointer getRmvarPointer(String opcode, long size) {
        if (!this.rmvarGPUPointers.containsKey(size)) {
            return null;
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Getting rmvar-ed pointers for size:" + size);
        }
        long t0 = opcode != null && DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
        Pointer A = this.remove(this.rmvarGPUPointers, size);
        this.addMiscTime(opcode, "r", t0);
        return A;
    }

    /** Removes and returns an arbitrary pointer from the set at {@code size}; the key must exist. */
    private Pointer remove(HashMap<Long, Set<Pointer>> hm, long size) {
        Pointer A = hm.get(size).iterator().next();
        this.remove(hm, size, A);
        return A;
    }

    /** Removes {@code ptr} from the set at {@code size}, dropping the key once its set is empty; the key must exist. */
    private void remove(HashMap<Long, Set<Pointer>> hm, long size, Pointer ptr) {
        Set<Pointer> ptrs = hm.get(size);
        ptrs.remove(ptr);
        if (ptrs.isEmpty()) {
            hm.remove(size);
        }
    }

    /** Summarizes locked/unlocked GPU object counts and sizes, and the total bytes tracked by this manager. */
    @Override
    public String toString() {
        long sizeOfLockedGPUObjects = 0L;
        long numLockedGPUObjects = 0L;
        long sizeOfUnlockedGPUObjects = 0L;
        long numUnlockedGPUObjects = 0L;
        for (GPUObject gpuObj : this.allocatedGPUObjects) {
            try {
                if (gpuObj.isLocked()) {
                    ++numLockedGPUObjects;
                    sizeOfLockedGPUObjects += gpuObj.getSizeOnDevice();
                } else {
                    ++numUnlockedGPUObjects;
                    sizeOfUnlockedGPUObjects += gpuObj.getSizeOnDevice();
                }
            } catch (DMLRuntimeException e) {
                throw new RuntimeException(e);
            }
        }
        long totalMemoryAllocated = 0L;
        for (long numBytes : this.allocatedGPUPointers.values()) {
            totalMemoryAllocated += numBytes;
        }
        return "Num of GPU objects: [unlocked:" + numUnlockedGPUObjects + ", locked:" + numLockedGPUObjects + "]. Size of GPU objects in bytes: [unlocked:" + sizeOfUnlockedGPUObjects + ", locked:" + sizeOfLockedGPUObjects + "]. Total memory allocated by the current GPU context in bytes:" + totalMemoryAllocated;
    }

    /** Returns the free device memory reported by {@code cudaMemGetInfo}, scaled by {@link #GPU_MEMORY_UTILIZATION_FACTOR}. */
    public long getAvailableMemory() {
        long[] free = new long[]{0L};
        long[] total = new long[]{0L};
        JCuda.cudaMemGetInfo(free, total);
        return (long) ((double) free[0] * this.GPU_MEMORY_UTILIZATION_FACTOR);
    }

    /**
     * Orders GPU objects for eviction. Locked objects sort first, so unlocked candidates end up at the
     * end of the list, which is where {@link GPUMemoryManager#malloc} evicts from.
     *
     * <p>Under {@code MIN_EVICT}, objects whose size is at least the needed size sort last, and among
     * them the smallest sorts last. Objects smaller than the needed size sort before those, in
     * ascending order of size. Under any other policy, objects sort by descending timestamp, so the
     * smallest timestamp is evicted first.
     */
    public static class GPUComparator implements Comparator<GPUObject> {
        private final long neededSize;

        public GPUComparator(long neededSize) {
            this.neededSize = neededSize;
        }

        @Override
        public int compare(GPUObject p1, GPUObject p2) {
            if (p1.isLocked() && p2.isLocked()) {
                return 0;
            }
            if (p1.isLocked()) {
                return -1;
            }
            if (p2.isLocked()) {
                return 1;
            }
            if (DMLScript.GPU_EVICTION_POLICY == DMLScript.EvictionPolicy.MIN_EVICT) {
                long p1Size;
                long p2Size;
                try {
                    p1Size = p1.getSizeOnDevice() - this.neededSize;
                    p2Size = p2.getSizeOnDevice() - this.neededSize;
                } catch (DMLRuntimeException e) {
                    throw new RuntimeException(e);
                }
                if (p1Size >= 0L && p2Size >= 0L) {
                    return Long.compare(p2Size, p1Size);
                }
                return Long.compare(p1Size, p2Size);
            }
            return Long.compare(p2.timestamp.get(), p1.timestamp.get());
        }
    }
}

