/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.util.Arrays;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNHelper;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;

/**
 * Static im2col / col2im kernels for dense and sparse matrix blocks.
 *
 * <p>Dimension naming used throughout (as evidenced by the index arithmetic below):
 * {@code C} input channels, {@code H x W} input height/width, {@code R x S} filter
 * height/width, {@code P x Q} output height/width. An im2col result has logical shape
 * {@code (C*R*S) x (P*Q)}, or {@code (P*Q) x (C*R*S)} when {@code trans} is set.
 */
public class LibMatrixDNNIm2Col {
    /**
     * Convenience overload that unpacks the convolution geometry from {@code params}
     * and delegates to the fully parameterized {@code im2col}.
     *
     * @param in     input block holding the images
     * @param out    output block receiving the im2col result
     * @param r      index of the image (row of {@code in}) to convert
     * @param params source of C, R, S, H, W, P, Q, strides and paddings
     * @param trans  if true, the transposed result layout is produced
     */
    public static void im2col(MatrixBlock in, MatrixBlock out, int r, ConvolutionParameters params, boolean trans) {
        LibMatrixDNNIm2Col.im2col(in, out, r, params.C, params.R, params.S, params.H, params.W, params.P, params.Q,
                params.stride_h, params.stride_w, params.pad_h, params.pad_w, trans);
    }

    /**
     * Dispatches to the matching im2col kernel: sparse input, dense input with
     * stride 1 / padding 0 and no transpose (fast path), or general dense input.
     *
     * @param in       input block; its {@code sparse} flag selects the kernel
     * @param out      output block (dense values are written directly for dense input)
     * @param r        index of the image to convert
     * @param C        number of input channels
     * @param R        filter height
     * @param S        filter width
     * @param H        input height
     * @param W        input width
     * @param P        output height
     * @param Q        output width
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param pad_h    vertical padding
     * @param pad_w    horizontal padding
     * @param trans    if true, the transposed result layout is produced
     */
    public static void im2col(MatrixBlock in, MatrixBlock out, int r, int C, int R, int S, int H, int W, int P, int Q, int stride_h, int stride_w, int pad_h, int pad_w, boolean trans) {
        if (in.sparse) {
            LibMatrixDNNIm2Col.im2colSparse(in, out, r, C, R, S, H, W, P, Q, stride_h, stride_w, pad_h, pad_w, trans);
            return;
        }
        boolean stride1Pad0 = stride_h == 1 && stride_w == 1 && pad_h == 0 && pad_w == 0;
        if (stride1Pad0 && !trans) {
            LibMatrixDNNIm2Col.im2colDenseStride1Pad0(in.getDenseBlockValues(), out.getDenseBlockValues(), r, C, R, S, H, W, P, Q);
        } else {
            LibMatrixDNNIm2Col.im2colDense(in.getDenseBlockValues(), out.getDenseBlockValues(), r, C, R, S, H, W, P, Q,
                    stride_h, stride_w, pad_h, pad_w, trans);
        }
    }

    /**
     * Dense im2col for stride 1 and padding 0 (non-transposed layout only).
     * Each output row segment of length {@code Q} is a contiguous slice of one input
     * row, so it is copied with a single {@code System.arraycopy}.
     *
     * @param in  dense input values; image {@code r} starts at offset {@code r*C*H*W}
     * @param out dense output values, written at {@code (c*P + p)*Q + q}
     * @param r   index of the image to convert
     * @param C   number of input channels
     * @param R   filter height
     * @param S   filter width
     * @param H   input height
     * @param W   input width
     * @param P   output height
     * @param Q   output width
     */
    public static void im2colDenseStride1Pad0(double[] in, double[] out, int r, int C, int R, int S, int H, int W, int P, int Q) {
        final int imageStart = r * C * H * W;
        final int CRS = C * R * S;
        final int lastCol = Q - 1;
        for (int c = 0; c < CRS; c++) {
            // decompose the output row index c into (channel, filter row, filter column)
            final int sOff = c % S;
            final int rOff = (c / S) % R;
            final int channel = c / R / S;
            for (int p = 0; p < P; p++) {
                final int hIn = p + rOff;
                final int outRow = (c * P + p) * Q;
                final int inRow = imageStart + (channel * H + hIn) * W;
                System.arraycopy(in, inRow + sOff, out, outRow, Q);
                // The last cell is re-derived with an explicit bounds test and zeroed when
                // its source position falls outside the H x W image.
                final int wIn = lastCol + sOff;
                out[outRow + lastCol] = (hIn < H && wIn < W) ? in[inRow + wIn] : 0.0;
            }
        }
    }

    /**
     * General dense im2col supporting arbitrary stride, padding and optional transpose.
     * The entire {@code out} array is zeroed first; positions that map into the padding
     * region therefore stay zero.
     *
     * @param in       dense input values; image {@code r} starts at offset {@code r*C*H*W}
     * @param out      dense output values; cell (c, p, q) is written at
     *                 {@code (c*P + p)*Q + q}, or at {@code (p*Q + q)*CRS + c} if {@code trans}
     * @param r        index of the image to convert
     * @param C        number of input channels
     * @param R        filter height
     * @param S        filter width
     * @param H        input height
     * @param W        input width
     * @param P        output height
     * @param Q        output width
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param pad_h    vertical padding
     * @param pad_w    horizontal padding
     * @param trans    if true, the transposed result layout is produced
     */
    public static void im2colDense(double[] in, double[] out, int r, int C, int R, int S, int H, int W, int P, int Q, int stride_h, int stride_w, int pad_h, int pad_w, boolean trans) {
        Arrays.fill(out, 0.0);
        final int CRS = C * R * S;
        final int imageStart = r * (C * H * W);
        // distance between consecutive q cells in the output array
        final int outStep = trans ? CRS : 1;
        for (int c = 0; c < CRS; c++) {
            final int sOff = c % S;
            final int rOff = (c / S) % R;
            final int channel = c / R / S;
            for (int p = 0; p < P; p++) {
                final int hIn = p * stride_h - pad_h + rOff;
                if (hIn < 0 || hIn >= H) {
                    continue; // source row lies in the vertical padding
                }
                final int inRow = imageStart + (channel * H + hIn) * W;
                final int outBase = trans ? c + p * Q * CRS : (c * P + p) * Q;
                for (int q = 0; q < Q; q++) {
                    final int wIn = q * stride_w - pad_w + sOff;
                    if (wIn >= 0 && wIn < W) {
                        out[outBase + q * outStep] = in[inRow + wIn];
                    }
                }
            }
        }
    }

    /**
     * Sparse im2col: resets {@code out} and scatters every stored value of sparse row
     * {@code r} of {@code in} to all output cells it contributes to, then sorts the
     * output's sparse rows. Returns right after the reset if row {@code r} is empty.
     *
     * @param in       sparse input block ({@code in.sparseBlock} is read)
     * @param out      output block; reset, appended to and finally sorted
     * @param r        index of the image (sparse row) to convert
     * @param C        number of input channels (not used by this kernel)
     * @param R        filter height
     * @param S        filter width
     * @param H        input height
     * @param W        input width
     * @param P        output height
     * @param Q        output width
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param pad_h    vertical padding
     * @param pad_w    horizontal padding
     * @param trans    if true, row/column indexes of appended cells are swapped
     */
    public static void im2colSparse(MatrixBlock in, MatrixBlock out, int r, int C, int R, int S, int H, int W, int P, int Q, int stride_h, int stride_w, int pad_h, int pad_w, boolean trans) {
        out.reset();
        SparseBlock sblock = in.sparseBlock;
        if (sblock.isEmpty(r)) {
            return;
        }
        final int apos = sblock.pos(r);
        final int alen = sblock.size(r);
        final int[] aix = sblock.indexes(r);
        final double[] avals = sblock.values(r);
        // Special case handled by a cheaper helper: unit stride, no padding, and a filter
        // as wide as the image (so there is a single output column).
        final boolean simple = stride_h == 1 && stride_w == 1 && pad_h == 0 && pad_w == 0 && W == S && Q == 1;
        final int RS = R * S;
        final int end = apos + alen;
        LibMatrixDNNHelper.CellIndex3 ix = new LibMatrixDNNHelper.CellIndex3();
        for (int j = apos; j < end; j++) {
            // NOTE(review): presumably decomposes the linear column index into
            // (channel, h, w) for an H x W image - confirm against LibMatrixDNNHelper.
            ix = LibMatrixDNNHelper.computeTensorIndexes(aix[j], H, W, ix);
            if (simple) {
                LibMatrixDNNIm2Col.appendInputValueToIm2colOutputSimple(out, ix.ix1, ix.ix2, ix.ix3, avals[j], R, S, RS, P, trans);
            } else {
                LibMatrixDNNIm2Col.appendInputValueToIm2colOutput(out, ix.ix1, ix.ix2, ix.ix3, avals[j], R, S, RS, P, Q,
                        stride_h, stride_w, pad_h, pad_w, trans);
            }
        }
        out.sortSparseRows();
    }

    /**
     * Appends {@code value}, located at input position (c, h, w), to every im2col output
     * cell it maps to under the given filter size, stride and padding.
     *
     * <p>A filter offset (r, s) receives the value at output position
     * {@code p = (h - r + pad_h) / stride_h}, {@code q = (w - s + pad_w) / stride_w};
     * the ranges below restrict r and s to offsets for which p and q are in bounds and
     * the divisions are exact.
     */
    private static void appendInputValueToIm2colOutput(MatrixBlock output, int c, int h, int w, double value, int R, int S, int RS, int P, int Q, int stride_h, int stride_w, int pad_h, int pad_w, boolean trans) {
        int rMin = Math.max(0, h + pad_h - P * stride_h + 1);
        final int rMax = Math.min(R - 1, h + pad_h);
        int sMin = Math.max(0, w + pad_w - Q * stride_w + 1);
        final int sMax = Math.min(S - 1, w + pad_w);
        // Advance each lower bound to the first offset aligned with the stride; the min()
        // caps the shift at one past the upper bound so the loops below simply do not run.
        rMin += Math.min((h - rMin + pad_h) % stride_h, rMax - rMin + 1);
        sMin += Math.min((w - sMin + pad_w) % stride_w, sMax - sMin + 1);
        for (int r = rMin, rowBase = c * RS + rMin * S; r <= rMax; r += stride_h, rowBase += stride_h * S) {
            final int pQ = (h - r + pad_h) / stride_h * Q;
            for (int s = sMin, ws = w - sMin + pad_w; s <= sMax; s += stride_w, ws -= stride_w) {
                final int crsIndex = rowBase + s;
                final int pqIndex = pQ + ws / stride_w;
                if (trans) {
                    output.appendValue(pqIndex, crsIndex, value);
                } else {
                    output.appendValue(crsIndex, pqIndex, value);
                }
            }
        }
    }

    /**
     * Variant of {@code appendInputValueToIm2colOutput} for the case selected by
     * {@code im2colSparse} (stride 1, padding 0, {@code W == S}, {@code Q == 1}):
     * the filter column equals {@code w} and only the filter row varies.
     */
    private static void appendInputValueToIm2colOutputSimple(MatrixBlock output, int c, int h, int w, double value, int R, int S, int RS, int P, boolean trans) {
        final int rMin = Math.max(0, h - P + 1);
        final int rMax = Math.min(R - 1, h);
        for (int r = rMin, cix = c * RS + w + rMin * S; r <= rMax; r++, cix += S) {
            final int p = h - r; // output row index fed by filter row r
            if (trans) {
                output.appendValue(p, cix, value);
            } else {
                output.appendValue(cix, p, value);
            }
        }
    }

    /**
     * Accumulates (col2im) a {@code (P*Q) x (C*R*S)} input block into image
     * {@code outputN} of the dense output {@code params.output}. Values are added to
     * the existing output contents.
     *
     * @param outputN index of the target image in the output
     * @param input   dense or sparse block of shape {@code (P*Q) x (C*R*S)}
     * @param params  convolution geometry and the output block
     * @throws DMLRuntimeException if the input dimensions do not match, the output is in
     *                             sparse format, or a sparse row index decomposes to a
     *                             non-zero leading tensor index
     */
    public static void col2imOverSingleImage(int outputN, MatrixBlock input, ConvolutionParameters params) throws DMLRuntimeException {
        if (input.rlen != params.P * params.Q || input.clen != params.C * params.R * params.S) {
            throw new DMLRuntimeException("Incorrect input dimensions");
        }
        if (params.output.isInSparseFormat()) {
            throw new DMLRuntimeException("Only dense output is implemented");
        }
        double[] outputArray = params.output.getDenseBlockValues();
        if (!input.isInSparseFormat()) {
            LibMatrixDNNIm2Col.col2IMDenseInput(0, outputN, input.getDenseBlockValues(), outputArray, params);
        } else if (!input.isEmptyBlock()) {
            LibMatrixDNNIm2Col.col2IMSparseInput(outputN, input, outputArray, params);
        }
    }

    /**
     * col2im for a non-empty sparse input: each stored cell (row = p*Q + q,
     * column = c*R*S + r*S + s) is added to output position (c, h, w) with
     * {@code h = p*stride_h - pad_h + r} and {@code w = q*stride_w - pad_w + s};
     * cells that land outside the H x W image are skipped.
     */
    private static void col2IMSparseInput(int outputN, MatrixBlock input, double[] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
        final int outOffset = outputN * params.C * params.H * params.W;
        final int HW = params.H * params.W;
        LibMatrixDNNHelper.CellIndex3 ix = new LibMatrixDNNHelper.CellIndex3();
        SparseBlock sblock = input.sparseBlock;
        for (int i = 0; i < input.getNumRows(); i++) {
            if (sblock.isEmpty(i)) {
                continue;
            }
            ix = LibMatrixDNNHelper.computeTensorIndexes(i, params.P, params.Q, ix);
            if (ix.ix1 != 0) {
                throw new DMLRuntimeException("Incorrect tensor indexes: " + ix + ", " + params.P + " " + params.Q);
            }
            // top-left input coordinate of the filter window for this output position;
            // captured now because ix is reused for the column decomposition below
            final int hBase = ix.ix2 * params.stride_h - params.pad_h;
            final int wBase = ix.ix3 * params.stride_w - params.pad_w;
            final int apos = sblock.pos(i);
            final int alen = sblock.size(i);
            final int[] aix = sblock.indexes(i);
            final double[] avals = sblock.values(i);
            final int end = apos + alen;
            for (int j = apos; j < end; j++) {
                ix = LibMatrixDNNHelper.computeTensorIndexes(aix[j], params.R, params.S, ix);
                final int h = hBase + ix.ix2;
                final int w = wBase + ix.ix3;
                if (h < 0 || h >= params.H || w < 0 || w >= params.W) {
                    continue; // contribution falls into the padding region
                }
                outputArray[outOffset + ix.ix1 * HW + h * params.W + w] += avals[j];
            }
        }
    }

    /**
     * col2im for dense input: for every output position (p, q) and channel c, adds the
     * in-bounds part of the R x S patch stored at {@code ((inputN*P + p)*Q + q)*CRS + c*RS}
     * into the H x W plane of channel c of image {@code outputN}.
     */
    private static void col2IMDenseInput(int inputN, int outputN, double[] inputArray, double[] outputArray, ConvolutionParameters params) {
        final int outputNOffset = outputN * params.C * params.H * params.W;
        final int HW = params.H * params.W;
        final int inputNPQ = inputN * params.P * params.Q;
        final int CRS = params.C * params.R * params.S;
        final int RS = params.R * params.S;
        for (int p = 0; p < params.P; p++) {
            final int hOffset = p * params.stride_h - params.pad_h;
            // clip the filter rows so that hOffset + r stays within [0, H)
            final int rStart = Math.max(0, -hOffset);
            final int rEnd = Math.min(params.R, params.H - hOffset);
            for (int q = 0; q < params.Q; q++) {
                final int wOffset = q * params.stride_w - params.pad_w;
                // clip the filter columns so that wOffset + s stays within [0, W)
                final int sStart = Math.max(0, -wOffset);
                final int sEnd = Math.min(params.S, params.W - wOffset);
                final int patchStart = (inputNPQ + p * params.Q + q) * CRS;
                for (int c = 0; c < params.C; c++) {
                    final int outOffset = outputNOffset + c * HW;
                    final int inputOffset = patchStart + c * RS;
                    for (int r = rStart; r < rEnd; r++) {
                        for (int s = sStart; s < sEnd; s++) {
                            final int inputIndex = inputOffset + r * params.S + s;
                            final int outIndex = outOffset + (hOffset + r) * params.W + wOffset + s;
                            outputArray[outIndex] += inputArray[inputIndex];
                        }
                    }
                }
            }
        }
    }

    /**
     * Pre-allocates every sparse row of {@code out} when {@code in} is sparse; does
     * nothing for dense input. The per-row capacity is four times the input sparsity
     * times the output column count, clamped to at most {@code out.clen} and at least 16.
     *
     * @param in  input block whose sparsity drives the estimate
     * @param out output block whose sparse rows are allocated
     */
    public static void preallocateSparseOutput(MatrixBlock in, MatrixBlock out) {
        if (!in.sparse) {
            return;
        }
        final int estnnz = (int)Math.ceil(4.0 * in.getSparsity() * (double)out.clen);
        // capacity and target block do not depend on the row, so compute them once
        final int rowCapacity = Math.max(Math.min(estnnz, out.clen), 16);
        SparseBlock target = out.getSparseBlock();
        for (int r = 0; r < out.rlen; r++) {
            target.allocate(r, rowCapacity);
        }
    }
}

