/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.cplan;

import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.collections.CollectionUtils;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.codegen.SpoofFusedOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.runtime.util.UtilFunctions;

public class CNodeMultiAgg
extends CNodeTpl {
    private static final String TEMPLATE = "package codegen;\nimport org.apache.sysml.runtime.codegen.LibSpoofPrimitives;\nimport org.apache.sysml.runtime.codegen.SpoofCellwise;\nimport org.apache.sysml.runtime.codegen.SpoofCellwise.AggOp;\nimport org.apache.sysml.runtime.codegen.SpoofMultiAggregate;\nimport org.apache.sysml.runtime.codegen.SpoofOperator.SideInput;\nimport org.apache.commons.math3.util.FastMath;\n\npublic final class %TMP% extends SpoofMultiAggregate { \n  public %TMP%() {\n    super(%SPARSE_SAFE%, %AGG_OP%);\n  }\n  protected void genexec(double a, SideInput[] b, double[] scalars, double[] c, int m, int n, int rix, int cix) { \n%BODY_dense%  }\n}\n";
    private static final String TEMPLATE_OUT_SUM = "    c[%IX%] += %IN%;\n";
    private static final String TEMPLATE_OUT_SUMSQ = "    c[%IX%] += %IN% * %IN%;\n";
    private static final String TEMPLATE_OUT_MIN = "    c[%IX%] = Math.min(c[%IX%], %IN%);\n";
    private static final String TEMPLATE_OUT_MAX = "    c[%IX%] = Math.max(c[%IX%], %IN%);\n";
    private ArrayList<CNode> _outputs = null;
    private ArrayList<Hop.AggOp> _aggOps = null;
    private ArrayList<Hop> _roots = null;
    private boolean _sparseSafe = false;

    public CNodeMultiAgg(ArrayList<CNode> inputs, ArrayList<CNode> outputs) {
        super(inputs, null);
        this._outputs = outputs;
    }

    public ArrayList<CNode> getOutputs() {
        return this._outputs;
    }

    @Override
    public void resetVisitStatusOutputs() {
        for (CNode output : this._outputs) {
            output.resetVisitStatus();
        }
    }

    public void setAggOps(ArrayList<Hop.AggOp> aggOps) {
        this._aggOps = aggOps;
        this._hash = 0;
    }

    public ArrayList<Hop.AggOp> getAggOps() {
        return this._aggOps;
    }

    public void setRootNodes(ArrayList<Hop> roots) {
        this._roots = roots;
    }

    public ArrayList<Hop> getRootNodes() {
        return this._roots;
    }

    public void setSparseSafe(boolean flag) {
        this._sparseSafe = flag;
    }

    public boolean isSparseSafe() {
        return this._sparseSafe;
    }

    @Override
    public void renameInputs() {
        this.rRenameDataNode(this._outputs, (CNode)this._inputs.get(0), "a");
        this.renameInputs(this._outputs, this._inputs, 1);
    }

    @Override
    public String codegen(boolean sparse) {
        String tmp = TEMPLATE;
        StringBuilder sb = new StringBuilder();
        for (CNode out : this._outputs) {
            sb.append(out.codegen(false));
        }
        for (CNode out : this._outputs) {
            out.resetGenerated();
        }
        for (int i = 0; i < this._outputs.size(); ++i) {
            CNode out;
            out = this._outputs.get(i);
            String tmpOut = this.getAggTemplate(i);
            String varName = out instanceof CNodeData && ((CNodeData)out).getHopID() == ((CNodeData)this._inputs.get(0)).getHopID() ? "a" : out.getVarname();
            tmpOut = tmpOut.replace("%IN%", varName);
            tmpOut = tmpOut.replace("%IX%", String.valueOf(i));
            sb.append(tmpOut);
        }
        tmp = tmp.replace("%TMP%", this.createVarname());
        tmp = tmp.replace("%BODY_dense%", sb.toString());
        String aggList = "";
        for (Hop.AggOp aggOp : this._aggOps) {
            aggList = aggList + (!aggList.isEmpty() ? "," : "");
            aggList = aggList + "AggOp." + aggOp.name();
        }
        tmp = tmp.replace("%AGG_OP%", aggList);
        tmp = tmp.replace("%SPARSE_SAFE%", String.valueOf(this.isSparseSafe()));
        return tmp;
    }

    @Override
    public void setOutputDims() {
    }

    @Override
    public SpoofFusedOp.SpoofOutputDimsType getOutputDimType() {
        return SpoofFusedOp.SpoofOutputDimsType.MULTI_SCALAR;
    }

    @Override
    public CNodeTpl clone() {
        CNodeMultiAgg ret = new CNodeMultiAgg((ArrayList<CNode>)this._inputs, this._outputs);
        ret.setAggOps(this.getAggOps());
        return ret;
    }

    @Override
    public int hashCode() {
        if (this._hash == 0) {
            int h = super.hashCode();
            for (int i = 0; i < this._outputs.size(); ++i) {
                h = UtilFunctions.intHashCode(h, UtilFunctions.intHashCode(this._outputs.get(i).hashCode(), this._aggOps.get(i).hashCode()));
            }
            this._hash = h;
        }
        return this._hash;
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof CNodeMultiAgg)) {
            return false;
        }
        CNodeMultiAgg that = (CNodeMultiAgg)o;
        return super.equals(o) && CollectionUtils.isEqualCollection(this._aggOps, that._aggOps) && CNodeMultiAgg.equalInputReferences(this._outputs, that._outputs, (ArrayList<CNode>)this._inputs, (ArrayList<CNode>)that._inputs);
    }

    @Override
    public String getTemplateInfo() {
        StringBuilder sb = new StringBuilder();
        sb.append("SPOOF MULTIAGG [aggOps=");
        sb.append(Arrays.toString((Object[])this._aggOps.toArray(new Hop.AggOp[0])));
        sb.append("]");
        return sb.toString();
    }

    private String getAggTemplate(int pos) {
        switch (this._aggOps.get(pos)) {
            case SUM: {
                return TEMPLATE_OUT_SUM;
            }
            case SUM_SQ: {
                return TEMPLATE_OUT_SUMSQ;
            }
            case MIN: {
                return TEMPLATE_OUT_MIN;
            }
            case MAX: {
                return TEMPLATE_OUT_MAX;
            }
        }
        return null;
    }
}

