/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.codegen;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysml.runtime.codegen.SpoofOperator;
import org.apache.sysml.runtime.compress.CompressedMatrixBlock;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;
import org.apache.sysml.runtime.matrix.data.SparseRow;
import org.apache.sysml.runtime.matrix.data.SparseRowVector;
import org.apache.sysml.runtime.util.UtilFunctions;

public abstract class SpoofRowwise
extends SpoofOperator {
    private static final long serialVersionUID = 6242910797139642998L;
    // Row-wise template variant; selects the output shape in allocateOutputMatrix
    // and the task kind used by the multi-threaded execute.
    protected final RowType _type;
    // When true, ROW_AGG outputs are allocated with two columns instead of one.
    protected final boolean _cbind0;
    // Forwarded as the last argument of prepInputMatrices.
    // NOTE(review): presumably "transpose the first side input" - confirm against SpoofOperator.
    protected final boolean _tB1;
    // First argument to LibSpoofPrimitives.setupThreadLocalMemory; values <= 0 skip
    // the thread-local setup/cleanup calls entirely.
    protected final int _reqVectMem;

    /**
     * Creates a row-wise operator with the given configuration.
     *
     * @param type row type of the operator, later returned by {@link #getRowType()}
     * @param cbind0 flag later returned by {@link #isCBind0()}
     * @param tB1 flag forwarded to the side-input preparation
     * @param reqVectMem value later returned by {@link #getNumIntermediates()}
     */
    public SpoofRowwise(RowType type, boolean cbind0, boolean tB1, int reqVectMem) {
        _type = type;
        _cbind0 = cbind0;
        _tB1 = tB1;
        _reqVectMem = reqVectMem;
    }

    /**
     * @return the row type this operator was constructed with
     */
    public RowType getRowType() {
        return _type;
    }

    /**
     * @return the cbind0 flag this operator was constructed with
     */
    public boolean isCBind0() {
        return _cbind0;
    }

    /**
     * @return the {@code reqVectMem} value this operator was constructed with
     */
    public int getNumIntermediates() {
        return _reqVectMem;
    }

    /**
     * Builds the operator type label: the prefix {@code "RA"} followed by the
     * second dot-separated segment of the runtime class name.
     *
     * @return the type label of this operator
     */
    @Override
    public String getSpoofType() {
        String runtimeName = getClass().getName();
        String[] segments = runtimeName.split("\\.");
        return "RA" + segments[1];
    }

    /**
     * Runs the operator into a fresh 1x1 block and returns its single cell as a scalar.
     *
     * @param inputs main input followed by side inputs
     * @param scalarObjects scalar inputs
     * @param k degree of parallelism; values above 1 select the multi-threaded path
     * @return the value of cell (0,0) of the result, wrapped as a {@link DoubleObject}
     * @throws DMLRuntimeException if the underlying execution fails
     */
    @Override
    public ScalarObject execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, int k) throws DMLRuntimeException {
        MatrixBlock result = new MatrixBlock(1, 1, false);
        if (k > 1) {
            result = execute(inputs, scalarObjects, result, k);
        } else {
            result = execute(inputs, scalarObjects, result);
        }
        double value = result.quickGetValue(0, 0);
        return new DoubleObject(value);
    }

    /**
     * Single-threaded execution that manages its own temporary memory
     * ({@code allocTmp = true}) and always resets the output ({@code aggIncr = false}).
     *
     * @param inputs main input followed by side inputs
     * @param scalarObjects scalar inputs
     * @param out output block to fill
     * @return the result block
     * @throws DMLRuntimeException if the underlying execution fails
     */
    @Override
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out) throws DMLRuntimeException {
        final boolean allocTmp = true;
        final boolean aggIncr = false;
        return execute(inputs, scalarObjects, out, allocTmp, aggIncr);
    }

    /**
     * Single-threaded execution of the row-wise operator over all rows of the main input.
     *
     * @param inputs main input at index 0 followed by side inputs
     * @param scalarObjects scalar inputs
     * @param out output block; reset and densely allocated unless {@code aggIncr} is set
     *            and the block is already allocated
     * @param allocTmp when true (and the operator requested vector memory), this call sets up
     *                 the thread-local scratch memory and releases it again before returning
     * @param aggIncr when true and {@code out} is already allocated, the existing output is
     *                reused instead of being reset
     * @return {@code out}, or a newly created transposed block when the output has to be flipped
     * @throws DMLRuntimeException if execution fails
     * @throws RuntimeException if {@code inputs} is null or empty, or {@code out} is null
     */
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, boolean allocTmp, boolean aggIncr) throws DMLRuntimeException {
        if (inputs == null || inputs.size() < 1 || out == null) {
            throw new RuntimeException("Invalid input arguments.");
        }
        MatrixBlock a = inputs.get(0);
        int m = a.getNumRows();
        int n = a.getNumColumns();
        // n2 is the smallest column count (> 1) among side inputs, or -1 if not applicable
        int n2 = this._type.isRowTypeB1() || SpoofRowwise.hasMatrixSideInput(inputs) ? SpoofRowwise.getMinColsMatrixSideInputs(inputs) : -1;
        if (!aggIncr || !out.isAllocated()) {
            this.allocateOutputMatrix(m, n, n2, out);
        }
        double[] c = out.getDenseBlock();
        boolean flipOut = this._type.isRowTypeB1ColumnAgg() && LibSpoofPrimitives.isFlipOuter(out.getNumRows(), out.getNumColumns());
        SpoofOperator.SideInput[] b = this.prepInputMatrices(inputs, 1, inputs.size() - 1, true, this._tB1);
        double[] scalars = SpoofRowwise.prepInputScalars(scalarObjects);
        boolean ownsTmpMemory = allocTmp && this._reqVectMem > 0;
        if (ownsTmpMemory) {
            LibSpoofPrimitives.setupThreadLocalMemory(this._reqVectMem, n, n2);
        }
        try {
            // dispatch on the physical representation of the main input
            if (a instanceof CompressedMatrixBlock) {
                this.executeCompressed((CompressedMatrixBlock)a, b, scalars, c, n, 0, m);
            } else if (!a.isInSparseFormat()) {
                this.executeDense(a.getDenseBlock(), b, scalars, c, n, 0, m);
            } else {
                this.executeSparse(a.getSparseBlock(), b, scalars, c, n, 0, m);
            }
        } finally {
            // release the scratch memory even if a generated kernel throws
            if (ownsTmpMemory) {
                LibSpoofPrimitives.cleanupThreadLocalMemory();
            }
        }
        out.recomputeNonZeros();
        if (flipOut) {
            // swap the declared dimensions first, then transpose into a new block
            this.fixTransposeDimensions(out);
            out = LibMatrixReorg.transpose(out, new MatrixBlock(out.getNumColumns(), out.getNumRows(), false));
        }
        out.examSparsity();
        return out;
    }

    /**
     * Multi-threaded execution of the row-wise operator. Falls back to the single-threaded
     * path when {@code k <= 1}, when a column aggregation is not eligible for parallel
     * execution according to {@code LibMatrixMult.checkParColumnAgg}, or when the total
     * number of input non-zeros is below 0x100000 (1,048,576).
     *
     * <p>Rows are split into contiguous blocks. Column and full aggregations compute one
     * partial result per task and add them into the output; all other types write
     * disjoint row ranges of the shared output directly.
     *
     * @param inputs main input at index 0 followed by side inputs
     * @param scalarObjects scalar inputs
     * @param out output block, reset and densely allocated by this call
     * @param k number of threads
     * @return {@code out}, or a newly created transposed block when the output has to be flipped
     * @throws DMLRuntimeException if any task or the post-processing fails
     * @throws RuntimeException if {@code inputs} is null or empty, or {@code out} is null
     */
    @Override
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k) throws DMLRuntimeException {
        // validate first: the fallback check below already dereferences inputs
        if (inputs == null || inputs.size() < 1 || out == null) {
            throw new RuntimeException("Invalid input arguments.");
        }
        if (k <= 1 || (this._type.isColumnAgg() && !LibMatrixMult.checkParColumnAgg(inputs.get(0), k, false)) || SpoofRowwise.getTotalInputNnz(inputs) < 0x100000L) {
            return this.execute(inputs, scalarObjects, out);
        }
        int m = inputs.get(0).getNumRows();
        int n = inputs.get(0).getNumColumns();
        int n2 = this._type.isRowTypeB1() || SpoofRowwise.hasMatrixSideInput(inputs) ? SpoofRowwise.getMinColsMatrixSideInputs(inputs) : -1;
        this.allocateOutputMatrix(m, n, n2, out);
        boolean flipOut = this._type.isRowTypeB1ColumnAgg() && LibSpoofPrimitives.isFlipOuter(out.getNumRows(), out.getNumColumns());
        SpoofOperator.SideInput[] b = this.prepInputMatrices(inputs, 1, inputs.size() - 1, true, this._tB1);
        double[] scalars = SpoofRowwise.prepInputScalars(scalarObjects);
        // partition rows into nk blocks of blklen rows each
        int nk = UtilFunctions.roundToNext(Math.min(8 * k, m / 32), k);
        int blklen = (int)Math.ceil((double)m / (double)nk);
        ExecutorService pool = Executors.newFixedThreadPool(k);
        try {
            if (this._type.isColumnAgg() || this._type == RowType.FULL_AGG) {
                List<ParColAggTask> tasks = new ArrayList<>();
                for (int i = 0; i < nk && i * blklen < m; ++i) {
                    tasks.add(new ParColAggTask(inputs.get(0), b, scalars, n, n2, i * blklen, Math.min((i + 1) * blklen, m)));
                }
                List<Future<double[]>> taskret = pool.invokeAll(tasks);
                // sum the per-task partial aggregates into the shared output
                int len = this._type.isColumnAgg() ? out.getNumRows() * out.getNumColumns() : 1;
                for (Future<double[]> task : taskret) {
                    LibMatrixMult.vectAdd(task.get(), out.getDenseBlock(), 0, 0, len);
                }
                out.recomputeNonZeros();
            } else {
                List<ParExecTask> tasks = new ArrayList<>();
                for (int i = 0; i < nk && i * blklen < m; ++i) {
                    tasks.add(new ParExecTask(inputs.get(0), b, out, scalars, n, n2, i * blklen, Math.min((i + 1) * blklen, m)));
                }
                List<Future<Long>> taskret = pool.invokeAll(tasks);
                // each task reports the non-zero count of its own row range
                long nnz = 0L;
                for (Future<Long> task : taskret) {
                    nnz += task.get();
                }
                out.setNonZeros(nnz);
            }
            if (flipOut) {
                this.fixTransposeDimensions(out);
                out = LibMatrixReorg.transpose(out, new MatrixBlock(out.getNumColumns(), out.getNumRows(), false));
            }
            out.examSparsity();
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        finally {
            // always release the worker threads, including when a task failed
            pool.shutdown();
        }
        return out;
    }

    /**
     * Checks whether any side input (every input except the one at index 0)
     * has more than one column.
     *
     * @param inputs main input at index 0 followed by side inputs
     * @return true if at least one side input has more than one column
     */
    public static boolean hasMatrixSideInput(ArrayList<MatrixBlock> inputs) {
        final int count = inputs.size();
        for (int pos = 1; pos < count; pos++) {
            if (inputs.get(pos).getNumColumns() > 1) {
                return true;
            }
        }
        return false;
    }

    /**
     * Finds the smallest column count among the side inputs that have more than one column.
     *
     * @param inputs main input at index 0 followed by side inputs
     * @return the minimum column count over side inputs with more than one column,
     *         or 1 if there is no such side input
     */
    private static int getMinColsMatrixSideInputs(ArrayList<MatrixBlock> inputs) {
        final int count = inputs.size();
        boolean found = false;
        int smallest = Integer.MAX_VALUE;
        for (int pos = 1; pos < count; pos++) {
            int ncol = inputs.get(pos).getNumColumns();
            if (ncol > 1) {
                found = true;
                smallest = Math.min(smallest, ncol);
            }
        }
        return found ? smallest : 1;
    }

    /**
     * Resets the output block to the dense shape implied by the row type and
     * allocates its dense storage.
     *
     * @param m number of rows of the main input
     * @param n number of columns of the main input
     * @param n2 column count used for the B1 row types
     * @param out output block to reset and allocate
     */
    private void allocateOutputMatrix(int m, int n, int n2, MatrixBlock out) {
        final int rows;
        final int cols;
        switch (_type) {
            case NO_AGG:       rows = m;  cols = n;  break;
            case NO_AGG_B1:    rows = m;  cols = n2; break;
            case FULL_AGG:     rows = 1;  cols = 1;  break;
            case ROW_AGG:      rows = m;  cols = 1 + (_cbind0 ? 1 : 0); break;
            case COL_AGG:      rows = 1;  cols = n;  break;
            case COL_AGG_T:    rows = n;  cols = 1;  break;
            case COL_AGG_B1:   rows = n2; cols = n;  break;
            case COL_AGG_B1_T: rows = n;  cols = n2; break;
            default:
                // no matching row type: allocate without resetting
                out.allocateDenseBlock();
                return;
        }
        out.reset(rows, cols, false);
        out.allocateDenseBlock();
    }

    /**
     * Swaps the row and column counts recorded on the block.
     *
     * @param out block whose dimensions are exchanged
     */
    private void fixTransposeDimensions(MatrixBlock out) {
        final int oldRows = out.getNumRows();
        final int oldCols = out.getNumColumns();
        out.setNumRows(oldCols);
        out.setNumColumns(oldRows);
    }

    /**
     * Invokes the dense kernel once per row in {@code [rl, ru)}; does nothing
     * when the dense array is null.
     *
     * @param a dense values of the main input, one row every {@code n} entries
     * @param b prepared side inputs
     * @param scalars scalar inputs
     * @param c output array
     * @param n number of columns of the main input
     * @param rl first row (inclusive)
     * @param ru last row (exclusive)
     */
    private void executeDense(double[] a, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int n, int rl, int ru) {
        if (a == null) {
            return;
        }
        for (int row = rl; row < ru; row++) {
            // row * n is the offset of this row's first value in a
            genexec(a, row * n, b, scalars, c, n, row);
        }
    }

    /**
     * Invokes the sparse kernel once per row in {@code [rl, ru)}. Rows that are
     * empty, and every row when the sparse block is null, are passed to the
     * kernel as an empty placeholder row of length 0.
     *
     * @param sblock sparse block of the main input, may be null
     * @param b prepared side inputs
     * @param scalars scalar inputs
     * @param c output array
     * @param n number of columns of the main input
     * @param rl first row (inclusive)
     * @param ru last row (exclusive)
     */
    private void executeSparse(SparseBlock sblock, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int n, int rl, int ru) {
        SparseRowVector placeholder = new SparseRowVector(1);
        for (int row = rl; row < ru; row++) {
            boolean emptyRow = sblock == null || sblock.isEmpty(row);
            if (emptyRow) {
                genexec(placeholder.values(), placeholder.indexes(), 0, b, scalars, c, 0, n, row);
            } else {
                genexec(sblock.values(row), sblock.indexes(row), sblock.pos(row), b, scalars, c, sblock.size(row), n, row);
            }
        }
    }

    /**
     * Invokes the kernel once per row delivered by the compressed block's row
     * iterator for {@code [rl, ru)}, using the dense or sparse kernel depending on
     * the block's format. An empty block is skipped entirely.
     *
     * @param a compressed main input
     * @param b prepared side inputs
     * @param scalars scalar inputs
     * @param c output array
     * @param n number of columns of the main input
     * @param rl first row (inclusive)
     * @param ru last row (exclusive)
     */
    private void executeCompressed(CompressedMatrixBlock a, SpoofOperator.SideInput[] b, double[] scalars, double[] c, int n, int rl, int ru) {
        if (a.isEmptyBlock(false)) {
            return;
        }
        if (a.isInSparseFormat()) {
            Iterator<SparseRow> rows = a.getSparseRowIterator(rl, ru);
            SparseRowVector placeholder = new SparseRowVector(1);
            for (int row = rl; rows.hasNext(); row++) {
                SparseRow current = rows.next();
                if (current.isEmpty()) {
                    // empty rows go to the kernel as a placeholder of length 0
                    genexec(placeholder.values(), placeholder.indexes(), 0, b, scalars, c, 0, n, row);
                } else {
                    genexec(current.values(), current.indexes(), 0, b, scalars, c, current.size(), n, row);
                }
            }
        } else {
            Iterator<double[]> rows = a.getDenseRowIterator(rl, ru);
            for (int row = rl; rows.hasNext(); row++) {
                // each iterator element is passed as a standalone row starting at offset 0
                genexec(rows.next(), 0, b, scalars, c, n, row);
            }
        }
    }

    /**
     * Dense row kernel, implemented by subclasses and called once per input row.
     * Argument roles as passed by the callers in this class:
     *
     * @param var1 dense values containing the row
     * @param var2 offset of the row's first value within {@code var1}
     * @param var3 prepared side inputs
     * @param var4 scalar inputs
     * @param var5 output array
     * @param var6 number of columns of the main input
     * @param var7 index of the row being processed
     */
    protected abstract void genexec(double[] var1, int var2, SpoofOperator.SideInput[] var3, double[] var4, double[] var5, int var6, int var7);

    /**
     * Sparse row kernel, implemented by subclasses and called once per input row.
     * Argument roles as passed by the callers in this class:
     *
     * @param var1 non-zero values of the row
     * @param var2 column indexes of the non-zero values
     * @param var3 start position of the row within {@code var1} and {@code var2}
     * @param var4 prepared side inputs
     * @param var5 scalar inputs
     * @param var6 output array
     * @param var7 number of non-zero entries in the row (0 for an empty row)
     * @param var8 number of columns of the main input
     * @param var9 index of the row being processed
     */
    protected abstract void genexec(double[] var1, int[] var2, int var3, SpoofOperator.SideInput[] var4, double[] var5, double[] var6, int var7, int var8, int var9);

    /**
     * Task that runs the operator on the row range {@code [rl, ru)} of the main
     * input, writing directly into the shared output block, and returns the
     * number of non-zeros found in those output rows.
     */
    private class ParExecTask
    implements Callable<Long> {
        private final MatrixBlock _a;
        private final SpoofOperator.SideInput[] _b;
        private final MatrixBlock _c;
        private final double[] _scalars;
        private final int _clen;
        private final int _clen2;
        private final int _rl;
        private final int _ru;

        /**
         * @param a main input
         * @param b prepared side inputs
         * @param c shared output block
         * @param scalars scalar inputs
         * @param clen number of columns of the main input
         * @param clen2 second column count passed to the thread-local memory setup
         * @param rl first row (inclusive)
         * @param ru last row (exclusive)
         */
        protected ParExecTask(MatrixBlock a, SpoofOperator.SideInput[] b, MatrixBlock c, double[] scalars, int clen, int clen2, int rl, int ru) {
            _a = a;
            _b = b;
            _c = c;
            _scalars = scalars;
            _clen = clen;
            _clen2 = clen2;
            _rl = rl;
            _ru = ru;
        }

        /**
         * Executes the kernel for the assigned rows.
         *
         * @return non-zero count of output rows {@code rl} to {@code ru - 1} over all columns
         */
        @Override
        public Long call() throws DMLRuntimeException {
            final boolean usesTmpMemory = _reqVectMem > 0;
            if (usesTmpMemory) {
                LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, _clen, _clen2);
            }
            final double[] target = _c.getDenseBlock();
            // dispatch on the physical representation of the main input
            if (_a instanceof CompressedMatrixBlock) {
                executeCompressed((CompressedMatrixBlock)_a, _b, _scalars, target, _clen, _rl, _ru);
            } else if (_a.isInSparseFormat()) {
                executeSparse(_a.getSparseBlock(), _b, _scalars, target, _clen, _rl, _ru);
            } else {
                executeDense(_a.getDenseBlock(), _b, _scalars, target, _clen, _rl, _ru);
            }
            if (usesTmpMemory) {
                LibSpoofPrimitives.cleanupThreadLocalMemory();
            }
            final int lastRow = _ru - 1;
            final int lastCol = _c.getNumColumns() - 1;
            return _c.recomputeNonZeros(_rl, lastRow, 0, lastCol);
        }
    }

    /**
     * Task that runs the operator on the row range {@code [rl, ru)} of the main
     * input into a private result array and returns that array as a partial
     * aggregate.
     */
    private class ParColAggTask
    implements Callable<double[]> {
        private final MatrixBlock _a;
        private final SpoofOperator.SideInput[] _b;
        private final double[] _scalars;
        private final int _clen;
        private final int _clen2;
        private final int _rl;
        private final int _ru;

        /**
         * @param a main input
         * @param b prepared side inputs
         * @param scalars scalar inputs
         * @param clen number of columns of the main input
         * @param clen2 second column count; when positive the partial result holds
         *              {@code clen * clen2} cells, otherwise {@code clen}
         * @param rl first row (inclusive)
         * @param ru last row (exclusive)
         */
        protected ParColAggTask(MatrixBlock a, SpoofOperator.SideInput[] b, double[] scalars, int clen, int clen2, int rl, int ru) {
            _a = a;
            _b = b;
            _scalars = scalars;
            _clen = clen;
            _clen2 = clen2;
            _rl = rl;
            _ru = ru;
        }

        /**
         * Executes the kernel for the assigned rows.
         *
         * @return the task-private partial result array
         */
        @Override
        public double[] call() throws DMLRuntimeException {
            final boolean usesTmpMemory = _reqVectMem > 0;
            if (usesTmpMemory) {
                LibSpoofPrimitives.setupThreadLocalMemory(_reqVectMem, _clen, _clen2);
            }
            final int partialLen = (_clen2 > 0) ? _clen * _clen2 : _clen;
            final double[] partial = new double[partialLen];
            // dispatch on the physical representation of the main input
            if (_a instanceof CompressedMatrixBlock) {
                executeCompressed((CompressedMatrixBlock)_a, _b, _scalars, partial, _clen, _rl, _ru);
            } else if (_a.isInSparseFormat()) {
                executeSparse(_a.getSparseBlock(), _b, _scalars, partial, _clen, _rl, _ru);
            } else {
                executeDense(_a.getDenseBlock(), _b, _scalars, partial, _clen, _rl, _ru);
            }
            if (usesTmpMemory) {
                LibSpoofPrimitives.cleanupThreadLocalMemory();
            }
            return partial;
        }
    }

    /**
     * Variants of the row-wise operator. The three predicates group the
     * constants into column aggregations, B1 types, and their intersection.
     */
    public static enum RowType {
        NO_AGG,
        NO_AGG_B1,
        FULL_AGG,
        ROW_AGG,
        COL_AGG,
        COL_AGG_T,
        COL_AGG_B1,
        COL_AGG_B1_T;

        /**
         * @return true for COL_AGG, COL_AGG_T, COL_AGG_B1 and COL_AGG_B1_T
         */
        public boolean isColumnAgg() {
            return this == COL_AGG || this == COL_AGG_T || isRowTypeB1ColumnAgg();
        }

        /**
         * @return true for NO_AGG_B1, COL_AGG_B1 and COL_AGG_B1_T
         */
        public boolean isRowTypeB1() {
            return this == NO_AGG_B1 || isRowTypeB1ColumnAgg();
        }

        /**
         * @return true for COL_AGG_B1 and COL_AGG_B1_T
         */
        public boolean isRowTypeB1ColumnAgg() {
            switch (this) {
                case COL_AGG_B1:
                case COL_AGG_B1_T:
                    return true;
                default:
                    return false;
            }
        }
    }
}

