/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.mr;

import java.util.ArrayList;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.mr.BinaryMRInstructionBase;
import org.apache.sysml.runtime.instructions.mr.IDistributedCacheConsumer;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;

/**
 * Map-side grouped-aggregate MR instruction.
 *
 * <p>For each block of {@code input1}, the matching block of group ids for {@code input2} is read
 * from the distributed cache. The block is then aggregated per group via
 * {@link OperationsOnMatrixValues#performMapGroupedAggregate}, and the resulting partial blocks are
 * added to {@code output}. The aggregation operator is always a Kahan-corrected sum
 * (correction stored in the last column).
 */
public class GroupedAggregateMInstruction
extends BinaryMRInstructionBase
implements IDistributedCacheConsumer {
    /** Number of groups; also the number of rows of the output (see computeOutputCharacteristics). */
    private final int _ngroups;

    /**
     * @param op      aggregation operator
     * @param in1     index of the data input
     * @param in2     index of the group-id input (consumed via the distributed cache)
     * @param out     index of the output
     * @param ngroups number of groups
     * @param istr    original instruction string; not used by this constructor
     */
    private GroupedAggregateMInstruction(Operator op, byte in1, byte in2, byte out, int ngroups, String istr) {
        super(op, in1, in2, out);
        this._ngroups = ngroups;
    }

    /**
     * Parses an instruction string into a {@code GroupedAggregateMInstruction}.
     *
     * <p>The string must split into exactly 5 parts. Parts 1-3 are the input1, input2 and output
     * indexes (bytes), and part 4 is the number of groups. Part 0 is presumably the opcode; it is
     * not read here.
     *
     * @param str instruction string
     * @return the parsed instruction, using a Kahan-plus aggregate with LASTCOLUMN correction
     * @throws DMLRuntimeException if the number of fields is wrong
     */
    public static GroupedAggregateMInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionParts(str);
        InstructionUtils.checkNumFields(parts, 5);
        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        byte out = Byte.parseByte(parts[3]);
        int ngroups = Integer.parseInt(parts[4]);
        AggregateOperator op = new AggregateOperator(0.0, KahanPlus.getKahanPlusFnObject(), true, PartialAggregate.CorrectionLocationType.LASTCOLUMN);
        return new GroupedAggregateMInstruction(op, in1, in2, out, ngroups, str);
    }

    /**
     * Aggregates every cached block of {@code input1} by group and emits the partial results.
     *
     * <p>If no blocks are cached for {@code input1}, this is a no-op. Null entries in the block list
     * are skipped.
     *
     * @throws DMLRuntimeException if {@code input2} is not available in the distributed cache, or
     *                             if the grouped aggregation fails
     */
    @Override
    public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) throws DMLRuntimeException {
        ArrayList<IndexedMatrixValue> blkList = cachedValues.get(this.input1);
        if (blkList == null) {
            return;
        }
        // The group-id input is the same for every block, so look it up only once.
        DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(this.input2);
        for (IndexedMatrixValue in1 : blkList) {
            if (in1 == null) {
                continue;
            }
            if (dcInput == null) {
                // Previously this surfaced as a bare NullPointerException at the same point.
                throw new DMLRuntimeException("GroupedAggregateMInstruction: missing distributed cache input for index " + this.input2 + ".");
            }
            MatrixIndexes ix = in1.getIndexes();
            // Groups are read as a single column: the block at (row block of in1, column block 1).
            MatrixBlock groups = (MatrixBlock)dcInput.getDataBlock((int)ix.getRowIndex(), 1).getValue();
            int brlen = dcInput.getNumRowsPerBlock();
            int bclen = dcInput.getNumColsPerBlock();
            ArrayList<IndexedMatrixValue> outlist = new ArrayList<>();
            OperationsOnMatrixValues.performMapGroupedAggregate(this.getOperator(), in1, groups, this._ngroups, brlen, bclen, outlist);
            for (IndexedMatrixValue out : outlist) {
                cachedValues.add(this.output, out);
            }
        }
    }

    /**
     * Returns true iff {@code index} is the group-id input and is not also used as the data input.
     */
    @Override
    public boolean isDistCacheOnlyIndex(String inst, byte index) {
        return index == this.input2 && index != this.input1;
    }

    /**
     * Registers the group-id input ({@code input2}) as a distributed-cache input.
     */
    @Override
    public void addDistCacheIndex(String inst, ArrayList<Byte> indexes) {
        indexes.add(this.input2);
    }

    /**
     * Sets the output characteristics. The output has {@code ngroups} rows, the same number of
     * columns as the input, and the input's block sizes.
     *
     * @param mcIn  characteristics of the data input
     * @param mcOut characteristics to populate for the output
     */
    public void computeOutputCharacteristics(MatrixCharacteristics mcIn, MatrixCharacteristics mcOut) {
        mcOut.set(this._ngroups, mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
    }
}

