/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.mr;

import java.util.ArrayList;
import java.util.Arrays;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.mr.BinaryInstruction;
import org.apache.sysml.runtime.instructions.mr.IDistributedCacheConsumer;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.LibMatrixOuterAgg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;

/**
 * MR instruction that combines a binary operator with a unary aggregate
 * over two inputs, one of which is read from the distributed cache.
 *
 * <p>Which input is taken from the distributed cache depends on the index
 * function of the unary aggregate and on whether
 * {@link LibMatrixOuterAgg#isSupportedUaggOp} accepts the operator pair;
 * see {@link #cachesRightInput()}.
 */
public class UaggOuterChainInstruction
extends BinaryInstruction
implements IDistributedCacheConsumer {
    /** Unary aggregate applied after the binary operation. */
    private final AggregateUnaryOperator _uaggOp;
    /** Aggregate operator used to merge partial results in the block-wise path. */
    private final AggregateOperator _aggOp;
    /** Binary operator applied between the two inputs. */
    private final BinaryOperator _bOp;
    /** Scratch block receiving the binary-operation result. */
    private final MatrixValue _tmpVal1;
    /** Scratch block receiving the unary-aggregate result. */
    private final MatrixValue _tmpVal2;
    /** Vector of the distributed-cache input; loaded lazily on first use and then reused. */
    private double[] _bv;
    /** Row indices prepared from {@link #_bv}; only populated for row-index-max/min aggregates. */
    private int[] _bvi;

    /**
     * Creates the instruction; use {@link #parseInstruction(String)} to obtain instances.
     *
     * @param bop    binary operator
     * @param uaggop unary aggregate operator
     * @param aggop  aggregate operator for merging partial aggregates
     * @param in1    index of the first input
     * @param in2    index of the second input
     * @param out    index of the output
     * @param istr   original instruction string
     */
    private UaggOuterChainInstruction(BinaryOperator bop, AggregateUnaryOperator uaggop, AggregateOperator aggop, byte in1, byte in2, byte out, String istr) {
        super(MRInstruction.MRType.UaggOuterChain, null, in1, in2, out, istr);
        this._uaggOp = uaggop;
        this._aggOp = aggop;
        this._bOp = bop;
        this._tmpVal1 = new MatrixBlock();
        this._tmpVal2 = new MatrixBlock();
        this.instString = istr;
    }

    /**
     * Builds an instruction from its string form.
     *
     * <p>After splitting, part 1 is the unary aggregate opcode, part 2 the binary
     * opcode, and parts 3 to 5 the byte indexes of input 1, input 2 and the output.
     *
     * @param str instruction string with five fields
     * @return the parsed instruction
     * @throws DMLRuntimeException if the instruction string cannot be parsed
     */
    public static UaggOuterChainInstruction parseInstruction(String str) throws DMLRuntimeException {
        InstructionUtils.checkNumFields(str, 5);
        final String[] fields = InstructionUtils.getInstructionParts(str);

        final AggregateUnaryOperator unaryAgg = InstructionUtils.parseBasicAggregateUnaryOperator(fields[1]);
        final BinaryOperator binary = InstructionUtils.parseBinaryOperator(fields[2]);

        final byte firstIn = Byte.parseByte(fields[3]);
        final byte secondIn = Byte.parseByte(fields[4]);
        final byte result = Byte.parseByte(fields[5]);

        // The merge aggregate is derived from the unary aggregate opcode.
        final String mergeOpcode = InstructionUtils.deriveAggregateOperatorOpcode(fields[1]);
        final PartialAggregate.CorrectionLocationType correctionLoc = InstructionUtils.deriveAggregateOperatorCorrectionLocation(fields[1]);
        final String hasCorrection = String.valueOf(correctionLoc != PartialAggregate.CorrectionLocationType.NONE);
        final AggregateOperator merge = InstructionUtils.parseAggregateOperator(mergeOpcode, hasCorrection, correctionLoc.toString());

        return new UaggOuterChainInstruction(binary, unaryAgg, merge, firstIn, secondIn, result, str);
    }

    /**
     * Sets the output dimensions according to the index function of the unary aggregate:
     * 1x1 for {@link ReduceAll}, rows(in1) x 1 for {@link ReduceCol}, and
     * 1 x cols(in2) for {@link ReduceRow}. Block sizes are taken from the rows-per-block
     * of input 1 and the columns-per-block of input 2. For any other index function
     * {@code mcOut} is left untouched.
     *
     * @param mcIn1 characteristics of the first input
     * @param mcIn2 characteristics of the second input
     * @param mcOut characteristics object to update
     */
    public void computeOutputCharacteristics(MatrixCharacteristics mcIn1, MatrixCharacteristics mcIn2, MatrixCharacteristics mcOut) {
        final long outRows;
        final long outCols;
        if (this._uaggOp.indexFn instanceof ReduceAll) {
            outRows = 1L;
            outCols = 1L;
        } else if (this._uaggOp.indexFn instanceof ReduceCol) {
            outRows = mcIn1.getRows();
            outCols = 1L;
        } else if (this._uaggOp.indexFn instanceof ReduceRow) {
            outRows = 1L;
            outCols = mcIn2.getCols();
        } else {
            // Unknown index function: nothing to set.
            return;
        }
        mcOut.set(outRows, outCols, mcIn1.getRowsPerBlock(), mcIn2.getColsPerBlock());
    }

    /**
     * Processes every block of the non-cached input against the distributed-cache input.
     *
     * <p>For each non-null block an output slot is reserved via
     * {@code cachedValues.holdPlace}. If {@link LibMatrixOuterAgg} supports the operator
     * pair, the cached vector is prepared once and the library aggregates directly;
     * otherwise the block is combined with each column block of the cached input and
     * the partial aggregates are merged incrementally.
     *
     * @param valueClass      value class used when reserving the output slot
     * @param cachedValues    map holding the streamed input blocks and receiving the output
     * @param tempValue       not used by this implementation
     * @param zeroInput       not used by this implementation
     * @param blockRowFactor  passed through to the unary aggregate
     * @param blockColFactor  passed through to the unary aggregate
     * @throws DMLRuntimeException if one of the underlying operations fails
     */
    @Override
    public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) throws DMLRuntimeException {
        final boolean rightCached = cachesRightInput();

        // The streamed side is whichever input is NOT served from the distributed cache.
        final ArrayList<IndexedMatrixValue> streamedBlocks;
        if (rightCached) {
            streamedBlocks = cachedValues.get(this.input1);
        } else {
            streamedBlocks = cachedValues.get(this.input2);
        }
        if (streamedBlocks == null) {
            return;
        }

        final byte cachedInputIx = rightCached ? this.input2 : this.input1;
        for (IndexedMatrixValue streamed : streamedBlocks) {
            if (streamed != null) {
                final MatrixIndexes inIx = streamed.getIndexes();
                final MatrixValue inVal = streamed.getValue();

                final IndexedMatrixValue outSlot = cachedValues.holdPlace(this.output, valueClass);
                final MatrixIndexes outIx = outSlot.getIndexes();
                final MatrixValue outVal = outSlot.getValue();

                final DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(cachedInputIx);

                if (LibMatrixOuterAgg.isSupportedUaggOp(this._uaggOp, this._bOp)) {
                    initCachedVector(dcInput, rightCached);
                    LibMatrixOuterAgg.resetOutputMatrix(inIx, (MatrixBlock)inVal, outIx, (MatrixBlock)outVal, this._uaggOp);
                    LibMatrixOuterAgg.aggregateMatrix((MatrixBlock)inVal, (MatrixBlock)outVal, this._bv, this._bvi, this._bOp, this._uaggOp);
                } else {
                    aggregateOverCachedColumnBlocks(inIx, inVal, outIx, outVal, dcInput, blockRowFactor, blockColFactor);
                }
            }
        }
    }

    /**
     * Tells whether the second input is the one read from the distributed cache.
     * This is the case for {@link ReduceCol} and {@link ReduceAll} index functions and
     * whenever {@link LibMatrixOuterAgg#isSupportedUaggOp} rejects the operator pair;
     * otherwise the first input is the cached one.
     */
    private boolean cachesRightInput() {
        return this._uaggOp.indexFn instanceof ReduceCol
            || this._uaggOp.indexFn instanceof ReduceAll
            || !LibMatrixOuterAgg.isSupportedUaggOp(this._uaggOp, this._bOp);
    }

    /**
     * Loads the cached input's vector into {@link #_bv} on first use. For row-index-max/min
     * aggregates the row indices are prepared into {@link #_bvi}; for all other aggregates
     * the vector is sorted in place. Does nothing once {@link #_bv} is set.
     */
    private void initCachedVector(DistributedCacheInput dcInput, boolean rightCached) throws DMLRuntimeException {
        final boolean rowIndexAgg = LibMatrixOuterAgg.isRowIndexMax(this._uaggOp) || LibMatrixOuterAgg.isRowIndexMin(this._uaggOp);
        if (this._bv != null) {
            return;
        }
        if (rightCached) {
            this._bv = dcInput.getRowVectorArray();
        } else {
            this._bv = dcInput.getColumnVectorArray();
        }
        if (rowIndexAgg) {
            this._bvi = LibMatrixOuterAgg.prepareRowIndices(this._bv.length, this._bv, this._bOp, this._uaggOp);
        } else {
            Arrays.sort(this._bv);
        }
    }

    /**
     * Fallback path: applies the binary operator between the streamed block and every
     * column block in block-row 1 of the cached input, aggregates each result with the
     * unary aggregate, and merges the partial aggregates into {@code outVal}.
     */
    private void aggregateOverCachedColumnBlocks(MatrixIndexes inIx, MatrixValue inVal, MatrixIndexes outIx, MatrixValue outVal, DistributedCacheInput dcInput, int blockRowFactor, int blockColFactor) throws DMLRuntimeException {
        final long cachedCols = dcInput.getNumCols();
        final long cachedColBlocks = (long)Math.ceil((double)cachedCols / (double)dcInput.getNumColsPerBlock());

        // Correction block; allocated together with the output on the first partial result.
        MatrixBlock correction = null;
        for (int colBlock = 1; (long)colBlock <= cachedColBlocks; colBlock++) {
            final IndexedMatrixValue cachedBlock = dcInput.getDataBlock(1, colBlock);
            final MatrixValue cachedVal = cachedBlock.getValue();

            OperationsOnMatrixValues.performBinaryIgnoreIndexes(inVal, cachedVal, this._tmpVal1, this._bOp);
            OperationsOnMatrixValues.performAggregateUnary(inIx, this._tmpVal1, outIx, this._tmpVal2, this._uaggOp, blockRowFactor, blockColFactor);

            if (correction == null) {
                final int partialRows = this._tmpVal2.getNumRows();
                final int partialCols = this._tmpVal2.getNumColumns();
                outVal.reset(partialRows, partialCols, false);
                correction = new MatrixBlock(this._tmpVal2.getNumRows(), this._tmpVal2.getNumColumns(), false);
            }

            if (this._aggOp.correctionExists) {
                OperationsOnMatrixValues.incrementalAggregation(outVal, correction, this._tmpVal2, this._aggOp, true);
            } else {
                OperationsOnMatrixValues.incrementalAggregation(outVal, null, this._tmpVal2, this._aggOp, true);
            }
        }
    }

    /**
     * Returns {@code true} if {@code index} is the distributed-cache input of this
     * instruction and is not also its other input.
     *
     * @param inst  not used by this implementation
     * @param index input index to test
     */
    @Override
    public boolean isDistCacheOnlyIndex(String inst, byte index) {
        final boolean rightCached = cachesRightInput();
        final byte cachedIx = rightCached ? this.input2 : this.input1;
        final byte streamedIx = rightCached ? this.input1 : this.input2;
        return index == cachedIx && index != streamedIx;
    }

    /**
     * Appends the index of the distributed-cache input to {@code indexes}.
     *
     * @param inst    not used by this implementation
     * @param indexes list receiving the index
     */
    @Override
    public void addDistCacheIndex(String inst, ArrayList<Byte> indexes) {
        if (cachesRightInput()) {
            indexes.add(this.input2);
            return;
        }
        indexes.add(this.input1);
    }
}

