/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.concurrent.Callable;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.DenseBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;

/**
 * Row-parallel kernels for the backward pass of ReLU.
 *
 * <p>Every kernel computes, element-wise, {@code c = (a > 0) ? b : 0}, where {@code a} is
 * {@code params.input1} and {@code b} is {@code params.input2} (presumably the forward-pass input
 * and the incoming gradient respectively — TODO confirm against callers). The result is written
 * into {@code params.output}. Rows are split into contiguous, disjoint ranges, one per task.
 */
public class LibMatrixDNNRelu {

    /**
     * Creates one {@link ReluBackward} task per contiguous row range of {@code [0, params.N)}.
     *
     * <p>Rows are divided into chunks of {@code ceil(N / k)} rows, where {@code k} is the
     * constrained thread count (at least 1). At most {@code k} tasks are produced, and none
     * when {@code N <= 0}.
     *
     * @param params parameters holding the inputs, the output, the row count {@code N} and the
     *               requested thread count
     * @return the tasks, ordered by ascending row range
     * @throws DMLRuntimeException declared for API compatibility; nothing in this method body
     *                             throws it explicitly
     */
    public static ArrayList<Callable<Long>> getReluBackwardWorkers(ConvolutionParameters params) throws DMLRuntimeException {
        ArrayList<Callable<Long>> ret = new ArrayList<Callable<Long>>();
        int numRows = params.N;
        // Clamp to at least one thread so the chunk size is positive and the loop below terminates.
        int k = Math.max(1, OptimizerUtils.getConstrainedNumThreads(params.numThreads));
        int taskSize = (int) Math.ceil((double) numRows / (double) k);
        // Advance by the previous upper bound instead of multiplying a task index by taskSize:
        // (i + 1) * taskSize can overflow int for very large N and produce negative row bounds.
        for (int rl = 0; rl < numRows; ) {
            int ru = (int) Math.min((long) rl + taskSize, (long) numRows);
            ret.add(new ReluBackward(rl, ru, params));
            rl = ru;
        }
        return ret;
    }

    /**
     * Dense x dense kernel: for rows {@code [rl, ru)} and columns {@code [0, n)} sets
     * {@code c = (a > 0) ? b : 0}. A NaN in {@code a} yields 0, because {@code NaN > 0} is false.
     *
     * <p>NOTE(review): the offset {@code a.pos(i)} is reused to index {@code b} and {@code c}.
     * This assumes all three dense blocks share the same dimensions and layout — confirm with
     * the caller.
     */
    private static void reluBackwardDenseDense(DenseBlock a, DenseBlock b, DenseBlock c, int n, int rl, int ru) {
        for (int i = rl; i < ru; ++i) {
            double[] avals = a.values(i);
            double[] bvals = b.values(i);
            double[] cvals = c.values(i);
            int ix = a.pos(i);
            for (int j = 0; j < n; ++j) {
                cvals[ix + j] = avals[ix + j] > 0.0 ? bvals[ix + j] : 0.0;
            }
        }
    }

    /**
     * Dense {@code a} x sparse {@code b} kernel writing into sparse {@code c}.
     *
     * <p>Only the stored entries of {@code b} are visited, since every other output cell is 0.
     * Rows where {@code b} is empty are skipped entirely.
     */
    private static void reluBackwardDenseSparse(DenseBlock a, SparseBlock b, SparseBlock c, int rl, int ru) {
        for (int i = rl; i < ru; ++i) {
            if (b.isEmpty(i)) continue;
            int bpos = b.pos(i);
            int blen = b.size(i);
            int[] bix = b.indexes(i);
            double[] bvals = b.values(i);
            double[] avals = a.values(i);
            int aix = a.pos(i);
            // Reserve room for at most one output entry per stored entry of b.
            c.allocate(i, blen);
            for (int k = bpos; k < bpos + blen; ++k) {
                c.append(i, bix[k], avals[aix + bix[k]] > 0.0 ? bvals[k] : 0.0);
            }
        }
    }

    /**
     * Sparse {@code a} x dense {@code b} kernel writing into sparse {@code c}.
     *
     * <p>Only the stored entries of {@code a} are visited, since an unstored entry of {@code a}
     * is 0, which is not {@code > 0}. Rows where {@code a} is empty are skipped.
     */
    private static void reluBackwardSparseDense(SparseBlock a, DenseBlock b, SparseBlock c, int rl, int ru) {
        for (int i = rl; i < ru; ++i) {
            if (a.isEmpty(i)) continue;
            int apos = a.pos(i);
            int alen = a.size(i);
            int[] aix = a.indexes(i);
            double[] avals = a.values(i);
            double[] bvals = b.values(i);
            int bix = b.pos(i);
            // Reserve room for at most one output entry per stored entry of a.
            c.allocate(i, alen);
            for (int k = apos; k < apos + alen; ++k) {
                c.append(i, aix[k], avals[k] > 0.0 ? bvals[bix + aix[k]] : 0.0);
            }
        }
    }

    /**
     * Sparse x sparse kernel writing into sparse {@code c}.
     *
     * <p>A row contributes only if both {@code a} and {@code b} have stored entries in it. For
     * each stored entry of {@code b}, the value of {@code a} at the same column is looked up via
     * {@code a.get}.
     */
    private static void reluBackwardSparseSparse(SparseBlock a, SparseBlock b, SparseBlock c, int rl, int ru) {
        for (int i = rl; i < ru; ++i) {
            if (a.isEmpty(i) || b.isEmpty(i)) continue;
            int bpos = b.pos(i);
            int blen = b.size(i);
            int[] bix = b.indexes(i);
            double[] bvals = b.values(i);
            c.allocate(i, blen);
            for (int k = bpos; k < bpos + blen; ++k) {
                c.append(i, bix[k], a.get(i, bix[k]) > 0.0 ? bvals[k] : 0.0);
            }
        }
    }

    /**
     * Task computing ReLU backward for the row range {@code [_rl, _ru)} of
     * {@code params.output}.
     */
    public static class ReluBackward
    implements Callable<Long> {
        /** First row (inclusive) processed by this task. */
        public final int _rl;
        /** Last row (exclusive) processed by this task. */
        public final int _ru;
        private final ConvolutionParameters _params;

        /**
         * @param rl     first row (inclusive)
         * @param ru     last row (exclusive)
         * @param params holder of {@code input1}, {@code input2} and {@code output}
         */
        public ReluBackward(int rl, int ru, ConvolutionParameters params) {
            this._rl = rl;
            this._ru = ru;
            this._params = params;
        }

        /**
         * Dispatches to the kernel matching the dense/sparse formats of the two inputs.
         *
         * <p>If either input is empty, nothing is written and 0 is returned. When at least one
         * input is sparse, the output is written through {@code out.getSparseBlock()}. Otherwise
         * it is written through {@code out.getDenseBlock()}. NOTE(review): this assumes the
         * caller pre-allocated {@code params.output} in the matching format — TODO confirm.
         *
         * @return the value of {@code out.recomputeNonZeros(_rl, _ru - 1)}, i.e. the non-zero
         *         count for this task's rows (the bounds appear to be inclusive — verify)
         */
        @Override
        public Long call() throws Exception {
            MatrixBlock m1 = this._params.input1;
            MatrixBlock m2 = this._params.input2;
            MatrixBlock out = this._params.output;
            int n = m1.getNumColumns();
            // An all-zero input makes every output cell 0, so there is nothing to compute.
            if (m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
                return 0L;
            }
            if (!m1.isInSparseFormat() && !m2.isInSparseFormat()) {
                LibMatrixDNNRelu.reluBackwardDenseDense(m1.getDenseBlock(), m2.getDenseBlock(), out.getDenseBlock(), n, this._rl, this._ru);
            } else if (!m1.isInSparseFormat() && m2.isInSparseFormat()) {
                LibMatrixDNNRelu.reluBackwardDenseSparse(m1.getDenseBlock(), m2.getSparseBlock(), out.getSparseBlock(), this._rl, this._ru);
            } else if (m1.isInSparseFormat() && !m2.isInSparseFormat()) {
                LibMatrixDNNRelu.reluBackwardSparseDense(m1.getSparseBlock(), m2.getDenseBlock(), out.getSparseBlock(), this._rl, this._ru);
            } else {
                LibMatrixDNNRelu.reluBackwardSparseSparse(m1.getSparseBlock(), m2.getSparseBlock(), out.getSparseBlock(), this._rl, this._ru);
            }
            // Per-task non-zero count over this task's rows only.
            return out.recomputeNonZeros(this._rl, this._ru - 1);
        }
    }
}

