/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNConv2d;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNPooling;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNRelu;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.util.CommonThreadPool;
import org.apache.sysml.runtime.util.ConvolutionUtils;

/**
 * Entry points for the CPU implementations of the deep-learning built-ins:
 * conv2d and its two backward passes, pooling and its backward pass,
 * relu_backward, bias_add and bias_multiply.
 *
 * <p>The convolution, pooling and relu entry points validate their inputs, attach the
 * matrices to a {@link ConvolutionParameters} instance, obtain worker tasks from
 * {@link LibMatrixDNNConv2d}, {@link LibMatrixDNNPooling} or {@link LibMatrixDNNRelu},
 * run them via {@link #execute} and then update the non-zero count of the output and
 * re-examine its sparsity. The bias operations are computed inline in this class.
 *
 * <p>The class also keeps a set of global counters and timers that are only reported
 * when {@code DMLScript.FINEGRAINED_STATISTICS} is enabled.
 */
public class LibMatrixDNN {
    protected static final Log LOG = LogFactory.getLog(LibMatrixDNN.class.getName());

    // Invocation counters, split by whether any relevant operand was in sparse format.
    // conv2d*/maxPoolBwd* are incremented in this class; im2col* are only reported and reset here.
    private static final AtomicLong conv2dSparseCount = new AtomicLong(0L);
    private static final AtomicLong conv2dDenseCount = new AtomicLong(0L);
    private static final AtomicLong conv2dBwdFilterSparseCount = new AtomicLong(0L);
    private static final AtomicLong conv2dBwdFilterDenseCount = new AtomicLong(0L);
    private static final AtomicLong conv2dBwdDataSparseCount = new AtomicLong(0L);
    private static final AtomicLong conv2dBwdDataDenseCount = new AtomicLong(0L);
    private static final AtomicLong im2colSparseCount = new AtomicLong(0L);
    private static final AtomicLong im2colDenseCount = new AtomicLong(0L);
    private static final AtomicLong maxPoolBwdSparseCount = new AtomicLong(0L);
    private static final AtomicLong maxPoolBwdDenseCount = new AtomicLong(0L);

    // Accumulated timings, reported in seconds after scaling by 1e-9.
    // Package-private and never written in this class apart from reset; presumably the
    // worker classes of this package add to them (verify before tightening visibility).
    static AtomicLong loopedConvMatMultTime = new AtomicLong(0L);
    static AtomicLong loopedConvIm2ColTime = new AtomicLong(0L);
    static AtomicLong loopedConvBwdFilterMatMultTime = new AtomicLong(0L);
    static AtomicLong loopedConvBwdFilterIm2ColTime = new AtomicLong(0L);
    static AtomicLong loopedConvBwdDataMatMultTime = new AtomicLong(0L);
    static AtomicLong loopedConvBwdDataCol2ImTime = new AtomicLong(0L);

    /**
     * Appends three lines of fine-grained statistics (dense counts, sparse counts and
     * timings) to {@code sb}. Does nothing unless {@code DMLScript.FINEGRAINED_STATISTICS}
     * is set.
     *
     * @param sb builder that receives the statistics text
     */
    public static void appendStatistics(StringBuilder sb) {
        if (!DMLScript.FINEGRAINED_STATISTICS) {
            return;
        }
        sb.append("LibMatrixDNN dense count (conv/bwdF/bwdD/im2col/maxBwd):\t")
            .append(conv2dDenseCount.get()).append("/")
            .append(conv2dBwdFilterDenseCount.get()).append("/")
            .append(conv2dBwdDataDenseCount.get()).append("/")
            .append(im2colDenseCount.get()).append("/")
            .append(maxPoolBwdDenseCount.get()).append(".\n");
        sb.append("LibMatrixDNN sparse count (conv/bwdF/bwdD/im2col/maxBwd):\t")
            .append(conv2dSparseCount.get()).append("/")
            .append(conv2dBwdFilterSparseCount.get()).append("/")
            .append(conv2dBwdDataSparseCount.get()).append("/")
            .append(im2colSparseCount.get()).append("/")
            .append(maxPoolBwdSparseCount.get()).append(".\n");
        sb.append("LibMatrixDNN conv(im2col/matmult), bwdF (im2col/matmult), bwdD (col2im/matmult) time:\t")
            .append(formatSeconds(loopedConvIm2ColTime)).append("/")
            .append(formatSeconds(loopedConvMatMultTime)).append("/")
            .append(formatSeconds(loopedConvBwdFilterIm2ColTime)).append("/")
            .append(formatSeconds(loopedConvBwdFilterMatMultTime)).append("/")
            .append(formatSeconds(loopedConvBwdDataCol2ImTime)).append("/")
            .append(formatSeconds(loopedConvBwdDataMatMultTime)).append(" sec.\n");
    }

    /** Formats a timer value scaled by 1e-9 with three decimals, as used in the statistics output. */
    private static String formatSeconds(AtomicLong timer) {
        return String.format("%.3f", (double) timer.get() * 1.0E-9);
    }

    /** Resets every counter and timer of this class to zero. */
    public static void resetStatistics() {
        AtomicLong[] all = {
            conv2dDenseCount, conv2dBwdFilterDenseCount, conv2dBwdDataDenseCount,
            im2colDenseCount, maxPoolBwdDenseCount,
            conv2dSparseCount, conv2dBwdFilterSparseCount, conv2dBwdDataSparseCount,
            im2colSparseCount, maxPoolBwdSparseCount,
            loopedConvIm2ColTime, loopedConvMatMultTime,
            loopedConvBwdFilterMatMultTime, loopedConvBwdFilterIm2ColTime,
            loopedConvBwdDataMatMultTime, loopedConvBwdDataCol2ImTime
        };
        for (AtomicLong stat : all) {
            stat.set(0L);
        }
    }

    /**
     * Runs the conv2d workers for the given input and filter.
     *
     * <p>If {@code params.bias} is set and sparse it is converted to dense in place first.
     *
     * @param input input matrix; must be {@code N x (C*H*W)}
     * @param filter filter matrix; must be {@code K x (C*R*S)}
     * @param outputBlock output matrix; its non-zero count is set from the workers' result
     * @param params convolution parameters; {@code input1}, {@code input2} and {@code output} are overwritten
     * @throws DMLRuntimeException if the dimensions or strides are invalid or a worker fails
     */
    public static void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        checkInputsConv2d(input, filter, outputBlock, params);
        if (params.bias != null && params.bias.isInSparseFormat()) {
            params.bias.sparseToDense();
        }
        long nnz = execute(LibMatrixDNNConv2d.getConv2dWorkers(params), params);
        outputBlock.setNonZeros(nnz);
        outputBlock.examSparsity();
    }

    /**
     * Runs the conv2d_backward_data workers.
     *
     * @param filter filter matrix; must be {@code K x (C*R*S)}
     * @param dout output-error matrix; must be {@code N x (K*P*Q)}
     * @param outputBlock output matrix; its non-zero count is set from the workers' result
     * @param params convolution parameters; {@code input1}, {@code input2} and {@code output} are overwritten
     * @throws DMLRuntimeException if the dimensions or strides are invalid or a worker fails
     */
    public static void conv2dBackwardData(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        checkInputsConv2dBackwardData(filter, dout, outputBlock, params);
        long nnz = execute(LibMatrixDNNConv2d.getConv2dBackwardDataWorkers(params), params);
        outputBlock.setNonZeros(nnz);
        outputBlock.examSparsity();
    }

    /**
     * Runs the conv2d_backward_filter workers.
     *
     * <p>Unlike the other entry points, the value returned by the workers is ignored and the
     * non-zero count is recomputed from the output itself.
     *
     * @param input input matrix; must be {@code N x (C*H*W)}
     * @param dout output-error matrix; must be {@code N x (K*P*Q)}
     * @param outputBlock output matrix
     * @param params convolution parameters; {@code input1}, {@code input2} and {@code output} are overwritten
     * @throws DMLRuntimeException if the dimensions or strides are invalid or a worker fails
     */
    public static void conv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        checkInputsConv2dBackwardFilter(input, dout, outputBlock, params);
        execute(LibMatrixDNNConv2d.getConv2dBackwardFilterWorkers(params), params);
        outputBlock.recomputeNonZeros();
        outputBlock.examSparsity();
    }

    /**
     * Runs the pooling workers of the requested type.
     *
     * @param input input matrix; must be {@code N x (C*H*W)}
     * @param output output matrix; its non-zero count is set from the workers' result
     * @param params pooling parameters; {@code input1} and {@code output} are overwritten
     * @param poolType pooling variant handed to the worker factory
     * @throws DMLRuntimeException if the input dimensions are wrong or a worker fails
     */
    public static void pooling(MatrixBlock input, MatrixBlock output, ConvolutionParameters params, PoolingType poolType) throws DMLRuntimeException {
        params.input1 = input;
        params.output = output;
        boolean colsOk = input.getNumColumns() == params.C * params.H * params.W;
        boolean rowsOk = input.getNumRows() == params.N;
        if (!colsOk || !rowsOk) {
            throw new DMLRuntimeException("Incorrect input dimensions in maxpooling:" + input.getNumRows() + " " + input.getNumColumns() + " " + params.N + " " + params.C * params.H * params.W);
        }
        // The window index arrays are skipped only for dense input with stride 1 / pad 0.
        if (!params.isStride1Pad0() || input.sparse) {
            fillIndexesArray(params);
        }
        long nnz = execute(LibMatrixDNNPooling.getPoolingWorkers(params, poolType), params);
        output.setNonZeros(nnz);
        output.examSparsity();
    }

    /**
     * Runs the pooling backward workers of the requested type.
     *
     * @param input original pooling input; for {@link PoolingType#MAX} it must be {@code N x (C*H*W)}
     * @param dout output-error matrix; must be {@code N x (C*P*Q)}
     * @param outputBlock output matrix; must not be in sparse format
     * @param params pooling parameters; {@code input1}, {@code input2} and {@code output} are overwritten
     * @param performReluBackward flag passed through to the worker factory
     * @param poolType pooling variant
     * @throws DMLRuntimeException if dimensions are wrong, the output is sparse, or a worker fails
     */
    public static void poolingBackward(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params, boolean performReluBackward, PoolingType poolType) throws DMLRuntimeException {
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;
        if (poolType == PoolingType.MAX && (input.getNumColumns() != params.C * params.H * params.W || input.getNumRows() != params.N)) {
            // Report the expectation that was actually checked (C*H*W).
            throw new DMLRuntimeException("Incorrect input dimensions in maxpooling_backward:" + input.getNumRows() + " " + input.getNumColumns() + " " + params.N + " " + params.C * params.H * params.W);
        }
        if (dout.getNumColumns() != params.C * params.P * params.Q || dout.getNumRows() != params.N) {
            // Report dout's own dimensions and the expectation that was actually checked (C*P*Q).
            throw new DMLRuntimeException("Incorrect dout dimensions in pooling_backward:" + dout.getNumRows() + " " + dout.getNumColumns() + " " + params.N + " " + params.C * params.P * params.Q);
        }
        if (DMLScript.FINEGRAINED_STATISTICS) {
            // For max pooling both operands count; otherwise only dout is considered.
            boolean isSparse = poolType == PoolingType.MAX
                ? input.isInSparseFormat() || dout.isInSparseFormat()
                : dout.isInSparseFormat();
            if (isSparse) {
                maxPoolBwdSparseCount.addAndGet(1L);
            } else {
                maxPoolBwdDenseCount.addAndGet(1L);
            }
        }
        if (params.output.isInSparseFormat()) {
            throw new DMLRuntimeException("Sparse pooling_backward is not supported");
        }
        // Index arrays are built for AVG always, and for the other type in every case
        // except sparse input combined with dense dout.
        boolean sparseInputDenseDout = params.input1.isInSparseFormat() && !params.input2.isInSparseFormat();
        if (poolType == PoolingType.AVG || !sparseInputDenseDout) {
            fillIndexesArray(params);
        }
        long nnz = execute(LibMatrixDNNPooling.getPoolingBackwardWorkers(params, performReluBackward, poolType), params);
        outputBlock.setNonZeros(nnz);
        outputBlock.examSparsity();
    }

    /**
     * Runs the relu_backward workers.
     *
     * @param input input matrix
     * @param dout output-error matrix; must have the same dimensions as {@code input}
     * @param outputBlock output matrix; its non-zero count is set from the workers' result
     * @param numThreads requested degree of parallelism
     * @throws DMLRuntimeException if the dimensions differ or a worker fails
     */
    public static void reluBackward(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, int numThreads) throws DMLRuntimeException {
        int N = input.getNumRows();
        // Only N and numThreads are meaningful here; the remaining parameters are passed as -1.
        ConvolutionParameters params = new ConvolutionParameters(N, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, numThreads);
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;
        boolean sameRows = input.getNumRows() == dout.getNumRows();
        boolean sameCols = input.getNumColumns() == dout.getNumColumns();
        if (!sameRows || !sameCols) {
            throw new DMLRuntimeException("Incorrect dimensions for relu_backward:" + input.getNumRows() + " != " + dout.getNumRows() + " || " + input.getNumColumns() + " != " + dout.getNumColumns());
        }
        long nnz = execute(LibMatrixDNNRelu.getReluBackwardWorkers(params), params);
        outputBlock.setNonZeros(nnz);
        outputBlock.examSparsity();
    }

    /**
     * Adds a per-channel bias to every row of {@code input} and writes the result to
     * {@code outputBlock}.
     *
     * <p>Each input row is treated as {@code K} consecutive groups of {@code PQ} columns,
     * where {@code K = bias.getNumRows()} and {@code PQ = input.getNumColumns() / K};
     * {@code bias[k]} is added to every value of group {@code k}. A sparse bias is
     * converted to dense in place.
     *
     * @param input input matrix; its column count must be a multiple of {@code K}
     * @param bias column vector with {@code K} rows
     * @param outputBlock output matrix; its dense values are written directly
     * @param numThreads not used by this method
     * @throws DMLRuntimeException if the bias is not a non-empty column vector or the input
     *     column count is not a multiple of {@code K}
     */
    public static void biasAdd(MatrixBlock input, MatrixBlock bias, MatrixBlock outputBlock, int numThreads) throws DMLRuntimeException {
        int N = input.getNumRows();
        int K = bias.getNumRows();
        // Validate before dividing by K so a zero-row bias yields the intended error
        // instead of an ArithmeticException.
        if (K <= 0 || bias.getNumColumns() != 1 || input.getNumColumns() % K != 0) {
            throw new DMLRuntimeException("Incorrect inputs for bias_add: input[" + N + " X " + input.getNumColumns() + "] and bias[" + K + " X " + bias.getNumColumns() + "]");
        }
        int PQ = input.getNumColumns() / K;
        // NOTE(review): the dense array is fetched before outputBlock.copy(input) below.
        // This relies on copy() keeping the same dense array; confirm against MatrixBlock.copy.
        double[] outputArray = outputBlock.getDenseBlockValues();
        if (input.isEmptyBlock()) {
            // Nothing to copy: the output is just the bias pattern, filled one row at a time.
            for (int n = 0; n < N; ++n) {
                ConvolutionUtils.fillBias(bias, outputArray, n, n + 1, N, K, PQ);
            }
        } else {
            outputBlock.copy(input);
            if (bias.isInSparseFormat()) {
                bias.sparseToDense();
            }
            double[] biasArr = bias.getDenseBlockValues();
            int index = 0;
            for (int n = 0; n < N; ++n) {
                for (int k = 0; k < K; ++k) {
                    double biasVal = biasArr[k];
                    for (int pq = 0; pq < PQ; ++pq, ++index) {
                        outputArray[index] += biasVal;
                    }
                }
            }
        }
        outputBlock.recomputeNonZeros();
        outputBlock.examSparsity();
    }

    /**
     * Multiplies every row of {@code input} by a per-channel bias and writes the result to
     * {@code outputBlock}.
     *
     * <p>Each input row is treated as {@code K} consecutive groups of {@code PQ} columns,
     * where {@code K = bias.getNumRows()} and {@code PQ = input.getNumColumns() / K};
     * every value of group {@code k} is multiplied by {@code bias[k]}. If either operand is
     * an empty block, only the output's non-zero count is set to zero. A sparse bias is
     * converted to dense in place.
     *
     * @param input input matrix; its column count must be a multiple of {@code K}
     * @param bias column vector with {@code K} rows
     * @param outputBlock output matrix
     * @param numThreads stored in the parameter object; the computation itself is single-threaded
     * @throws DMLRuntimeException if the bias is not a non-empty column vector or the input
     *     column count is not a multiple of {@code K}
     */
    public static void biasMultiply(MatrixBlock input, MatrixBlock bias, MatrixBlock outputBlock, int numThreads) throws DMLRuntimeException {
        int N = input.getNumRows();
        int K = bias.getNumRows();
        // Validate before dividing by K so a zero-row bias yields the intended error
        // instead of an ArithmeticException.
        if (K <= 0 || bias.getNumColumns() != 1 || input.getNumColumns() % K != 0) {
            throw new DMLRuntimeException("Incorrect inputs for bias_multiply: input[" + N + " X " + input.getNumColumns() + "] and bias[" + K + " X " + bias.getNumColumns() + "]");
        }
        int PQ = input.getNumColumns() / K;
        ConvolutionParameters params = new ConvolutionParameters(N, PQ, -1, -1, K, -1, -1, -1, -1, -1, -1, numThreads);
        params.input1 = input;
        params.input2 = bias;
        params.output = outputBlock;
        if (input.isEmptyBlock() || bias.isEmptyBlock()) {
            params.output.setNonZeros(0L);
            return;
        }
        outputBlock.copy(input);
        if (bias.isInSparseFormat()) {
            bias.sparseToDense();
        }
        double[] biasArr = bias.getDenseBlockValues();
        if (!input.isInSparseFormat()) {
            double[] outputArray = outputBlock.getDenseBlockValues();
            int index = 0;
            for (int n = 0; n < N; ++n) {
                for (int k = 0; k < K; ++k) {
                    double biasVal = biasArr[k];
                    for (int pq = 0; pq < PQ; ++pq, ++index) {
                        outputArray[index] *= biasVal;
                    }
                }
            }
        } else {
            // Channels with a zero bias produce zeros: drop their column range from every row.
            for (int k = 0; k < K; ++k) {
                if (biasArr[k] != 0.0) {
                    continue;
                }
                for (int n = 0; n < N; ++n) {
                    outputBlock.sparseBlock.deleteIndexRange(n, k * PQ, (k + 1) * PQ);
                }
            }
            // Scale the remaining non-zeros by the bias of their channel.
            for (int n = 0; n < N; ++n) {
                if (outputBlock.sparseBlock.isEmpty(n)) {
                    continue;
                }
                int apos = outputBlock.sparseBlock.pos(n);
                int alen = outputBlock.sparseBlock.size(n);
                int[] aix = outputBlock.sparseBlock.indexes(n);
                double[] avals = outputBlock.sparseBlock.values(n);
                for (int j = apos; j < apos + alen; ++j) {
                    // Column = k*PQ + pq, so the channel is the quotient (not the remainder),
                    // matching the dense path and the deleteIndexRange bounds above.
                    int k = aix[j] / PQ;
                    if (biasArr[k] != 0.0) {
                        avals[j] *= biasArr[k];
                    }
                }
            }
        }
        params.output.recomputeNonZeros();
        params.output.examSparsity();
    }

    /**
     * Runs the given tasks and sums their results.
     *
     * <p>With a constrained thread count of 1 the tasks are called sequentially on the
     * calling thread; otherwise they are submitted to a pool of
     * {@code min(k, params.N)} threads obtained from {@link CommonThreadPool}.
     *
     * @param tasks worker tasks, each returning a partial count
     * @param params parameters supplying {@code numThreads} and {@code N}
     * @return the sum of all task results
     * @throws DMLRuntimeException wrapping any exception raised while running the tasks
     */
    private static long execute(ArrayList<Callable<Long>> tasks, ConvolutionParameters params) throws DMLRuntimeException {
        int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        long lnnz = 0L;
        try {
            if (k == 1) {
                for (Callable<Long> task : tasks) {
                    lnnz += task.call().longValue();
                }
            } else {
                ExecutorService pool = CommonThreadPool.get(Math.min(k, params.N));
                try {
                    List<Future<Long>> results = pool.invokeAll(tasks);
                    for (Future<Long> result : results) {
                        lnnz += result.get().longValue();
                    }
                } finally {
                    // Always release the pool, including when invokeAll or get throws.
                    pool.shutdown();
                }
            }
        }
        catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("Error while executing multi-threaded tasks", e);
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Error while executing multi-threaded tasks", e);
        }
        return lnnz;
    }

    /**
     * Throws if {@code lhs} differs from {@code rhs}.
     *
     * @param msg message prefix for the exception
     * @param lhs actual value
     * @param rhs expected value
     * @throws DMLRuntimeException if the values differ
     */
    private static void checkOrThrowException(String msg, long lhs, long rhs) throws DMLRuntimeException {
        if (lhs != rhs) {
            throw new DMLRuntimeException(msg + ":" + lhs + " != " + rhs);
        }
    }

    /**
     * Throws if {@code lhs} differs from the product {@code rhs1 * rhs2 * rhs3}.
     *
     * @param msg message prefix for the exception
     * @param lhs actual value
     * @param rhs1 first factor of the expected value
     * @param rhs2 second factor of the expected value
     * @param rhs3 third factor of the expected value
     * @throws DMLRuntimeException if the values differ
     */
    private static void checkOrThrowException(String msg, long lhs, long rhs1, long rhs2, long rhs3) throws DMLRuntimeException {
        if (lhs != rhs1 * rhs2 * rhs3) {
            throw new DMLRuntimeException(msg + ":" + lhs + " != (" + rhs1 + " * " + rhs2 + " * " + rhs3 + ")");
        }
    }

    /**
     * Attaches the operands to {@code params}, validates shapes and strides for
     * conv2d_backward_data and, with fine-grained statistics enabled, bumps the matching
     * sparse/dense counter.
     *
     * @param filter filter matrix; must be {@code K x (C*R*S)}
     * @param dout output-error matrix; must be {@code N x (K*P*Q)}
     * @param outputBlock output matrix stored in {@code params.output}
     * @param params convolution parameters to validate against and to populate
     * @throws DMLRuntimeException if a dimension does not match or a stride is not positive
     */
    static void checkInputsConv2dBackwardData(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        params.input1 = filter;
        params.input2 = dout;
        params.output = outputBlock;
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of rows of input filter != number of filters in filter_shape", filter.getNumRows(), params.K);
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of columns of input filter != channels*filter_height*filter_width in filter_shape", filter.getNumColumns(), params.C, params.R, params.S);
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of rows of input errors != batch size in input_shape", dout.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of columns of input errors != expected input error channels*height*width", dout.getNumColumns(), params.K, params.P, params.Q);
        if (params.stride_h <= 0 || params.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported:" + params.stride_h + ", " + params.stride_w);
        }
        if (DMLScript.FINEGRAINED_STATISTICS) {
            boolean anySparse = filter.isInSparseFormat() || dout.isInSparseFormat();
            (anySparse ? conv2dBwdDataSparseCount : conv2dBwdDataDenseCount).addAndGet(1L);
        }
    }

    /**
     * Attaches the operands to {@code params}, validates shapes and strides for
     * conv2d_backward_filter and, with fine-grained statistics enabled, bumps the matching
     * sparse/dense counter.
     *
     * @param input input matrix; must be {@code N x (C*H*W)}
     * @param dout output-error matrix; must be {@code N x (K*P*Q)}
     * @param outputBlock output matrix stored in {@code params.output}
     * @param params convolution parameters to validate against and to populate
     * @throws DMLRuntimeException if a dimension does not match or a stride is not positive
     */
    static void checkInputsConv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of rows of input data != batch size in input_shape", input.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of columns of input data != channels*input_height*input_width in input_shape", input.getNumColumns(), params.C, params.H, params.W);
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of rows of input errors != batch size in input_shape", dout.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of columns of input errors != expected input error channels*height*width", dout.getNumColumns(), params.K, params.P, params.Q);
        if (params.stride_h <= 0 || params.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported:" + params.stride_h + ", " + params.stride_w);
        }
        if (DMLScript.FINEGRAINED_STATISTICS) {
            boolean anySparse = input.isInSparseFormat() || dout.isInSparseFormat();
            (anySparse ? conv2dBwdFilterSparseCount : conv2dBwdFilterDenseCount).addAndGet(1L);
        }
    }

    /**
     * Attaches the operands to {@code params}, validates shapes and strides for conv2d and,
     * with fine-grained statistics enabled, bumps the matching sparse/dense counter.
     *
     * @param input input matrix; must be {@code N x (C*H*W)}
     * @param filter filter matrix; must be {@code K x (C*R*S)}
     * @param outputBlock output matrix stored in {@code params.output}
     * @param params convolution parameters to validate against and to populate
     * @throws DMLRuntimeException if a dimension does not match or a stride is not positive
     */
    static void checkInputsConv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        params.input1 = input;
        params.input2 = filter;
        params.output = outputBlock;
        checkOrThrowException("Incorrect input to conv2d: Number of rows of input filter != number of filters in filter_shape", filter.getNumRows(), params.K);
        checkOrThrowException("Incorrect input to conv2d: Number of columns of input filter != channels*filter_height*filter_width in filter_shape", filter.getNumColumns(), params.C, params.R, params.S);
        checkOrThrowException("Incorrect input to conv2d: Number of rows of input data != batch size in input_shape", input.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d: Number of columns of input data != channels*input_height*input_width in input_shape", input.getNumColumns(), params.C, params.H, params.W);
        if (params.stride_h <= 0 || params.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported:" + params.stride_h + ", " + params.stride_w);
        }
        if (DMLScript.FINEGRAINED_STATISTICS) {
            boolean anySparse = input.isInSparseFormat() || filter.isInSparseFormat();
            (anySparse ? conv2dSparseCount : conv2dDenseCount).addAndGet(1L);
        }
    }

    /**
     * Allocates and fills the per-output-position window bounds in {@code params}.
     *
     * <p>For output row {@code p} the window begins at {@code p * stride_h - pad_h} and spans
     * {@code R} rows; the stored start is clamped to at least 0 and the stored end to at
     * most {@code H}. Columns are handled the same way with {@code Q}, {@code stride_w},
     * {@code pad_w}, {@code S} and {@code W}.
     *
     * @param params parameters whose {@code start_indexes_*} and {@code end_indexes_*} arrays are replaced
     */
    private static void fillIndexesArray(ConvolutionParameters params) {
        params.start_indexes_h = new int[params.P];
        params.end_indexes_h = new int[params.P];
        params.start_indexes_w = new int[params.Q];
        params.end_indexes_w = new int[params.Q];
        for (int p = 0, ix = -params.pad_h; p < params.P; ++p, ix += params.stride_h) {
            params.start_indexes_h[p] = Math.max(ix, 0);
            params.end_indexes_h[p] = Math.min(ix + params.R, params.H);
        }
        for (int q = 0, ix = -params.pad_w; q < params.Q; ++q, ix += params.stride_w) {
            params.start_indexes_w[q] = Math.max(ix, 0);
            params.end_indexes_w[q] = Math.min(ix + params.S, params.W);
        }
    }

    /** Pooling variants accepted by {@link #pooling} and {@link #poolingBackward}. */
    public static enum PoolingType {
        MAX,
        AVG;

    }
}

