/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeRow;
import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.TemplateBase;
import org.apache.sysml.hops.codegen.template.TemplateCell;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.matrix.data.Pair;

/**
 * Code generation template of type {@link TemplateBase.TemplateType#RowTpl}.
 * <p>
 * Decides which HOPs may open, fuse into, merge with, or close a row template, and compiles a
 * selected HOP sub-DAG into a {@link CNodeRow} plan via {@link #constructCplan}.
 */
public class TemplateRow
extends TemplateBase {
    /** Aggregates supported in row direction; compiled to {@code ROW_SUMS}/{@code ROW_MINS}/{@code ROW_MAXS}. */
    private static final Hop.AggOp[] SUPPORTED_ROW_AGG = new Hop.AggOp[]{Hop.AggOp.SUM, Hop.AggOp.MIN, Hop.AggOp.MAX};
    /** Unary operations supported on matrix inputs; compiled to {@code VECT_<op>} unary nodes. */
    private static final Hop.OpOp1[] SUPPORTED_VECT_UNARY = new Hop.OpOp1[]{Hop.OpOp1.EXP, Hop.OpOp1.SQRT, Hop.OpOp1.LOG, Hop.OpOp1.ABS, Hop.OpOp1.ROUND, Hop.OpOp1.CEIL, Hop.OpOp1.FLOOR, Hop.OpOp1.SIGN};
    /** Binary operations supported on matrix inputs; compiled to {@code VECT_<op>_SCALAR} binary nodes. */
    private static final Hop.OpOp2[] SUPPORTED_VECT_BINARY = new Hop.OpOp2[]{Hop.OpOp2.MULT, Hop.OpOp2.DIV, Hop.OpOp2.MINUS, Hop.OpOp2.PLUS, Hop.OpOp2.POW, Hop.OpOp2.MIN, Hop.OpOp2.MAX, Hop.OpOp2.EQUAL, Hop.OpOp2.NOTEQUAL, Hop.OpOp2.LESS, Hop.OpOp2.LESSEQUAL, Hop.OpOp2.GREATER, Hop.OpOp2.GREATEREQUAL};

    /** Creates a row template using the {@link TemplateBase} defaults. */
    public TemplateRow() {
        super(TemplateBase.TemplateType.RowTpl);
    }

    /**
     * Creates a row template with the given closed flag.
     *
     * @param closed initial closed state, passed through to {@link TemplateBase}
     */
    public TemplateRow(boolean closed) {
        super(TemplateBase.TemplateType.RowTpl, closed);
    }

    /**
     * Returns whether {@code hop} can start a row template. This holds for:
     * <ul>
     *   <li>a valid binary op whose first input has more than one column and whose second input
     *       has exactly one column;</li>
     *   <li>an aggregate binary op producing a single column from a first input with more than
     *       one row and more than one column;</li>
     *   <li>a non-full (not RowCol) unary aggregate over an input with more than one row and more
     *       than one column.</li>
     * </ul>
     */
    @Override
    public boolean open(Hop hop) {
        return (hop instanceof BinaryOp
                    && hop.getInput().get(0).getDim2() > 1L
                    && hop.getInput().get(1).getDim2() == 1L
                    && TemplateCell.isValidOperation(hop))
            || (hop instanceof AggBinaryOp
                    && hop.getDim2() == 1L
                    && hop.getInput().get(0).getDim1() > 1L
                    && hop.getInput().get(0).getDim2() > 1L)
            || (hop instanceof AggUnaryOp
                    && ((AggUnaryOp) hop).getDirection() != Hop.Direction.RowCol
                    && hop.getInput().get(0).getDim1() > 1L
                    && hop.getInput().get(0).getDim2() > 1L);
    }

    /**
     * Returns whether {@code hop} can be fused into this (not yet closed) template. This holds for:
     * supported matrix-column-vector or matrix-scalar binary ops; valid unary or parameterized
     * builtin ops; non-full unary aggregates; and aggregate binary ops with more than one output
     * row whose first input is a transpose.
     */
    @Override
    public boolean fuse(Hop hop, Hop input) {
        return !this.isClosed()
            && ((hop instanceof BinaryOp
                    && TemplateUtils.isOperationSupported(hop)
                    && (HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
                        || HopRewriteUtils.isBinaryMatrixScalarOperation(hop)))
                || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp)
                    && TemplateCell.isValidOperation(hop))
                || (hop instanceof AggUnaryOp
                    && ((AggUnaryOp) hop).getDirection() != Hop.Direction.RowCol)
                || (hop instanceof AggBinaryOp
                    && hop.getDim1() > 1L
                    && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
    }

    /**
     * Returns whether this (not yet closed) template, rooted at {@code input}, can be merged into
     * {@code hop}. Requires {@code input} to have exactly one column, and {@code hop} to be either
     * a supported binary op or an aggregate binary op whose first input is a transpose.
     */
    @Override
    public boolean merge(Hop hop, Hop input) {
        return !this.isClosed()
            && ((hop instanceof BinaryOp
                    && input.getDim2() == 1L
                    && TemplateUtils.isOperationSupported(hop))
                || (hop instanceof AggBinaryOp
                    && input.getDim2() == 1L
                    && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
    }

    /**
     * Returns {@link TemplateBase.CloseType#CLOSED_VALID} for column aggregates and for aggregate
     * binary ops whose first input is a transpose; {@link TemplateBase.CloseType#OPEN} otherwise.
     */
    @Override
    public TemplateBase.CloseType close(Hop hop) {
        boolean colAgg = hop instanceof AggUnaryOp
            && ((AggUnaryOp) hop).getDirection() == Hop.Direction.Col;
        boolean transposedMatMult = hop instanceof AggBinaryOp
            && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0));
        return (colAgg || transposedMatMult)
            ? TemplateBase.CloseType.CLOSED_VALID
            : TemplateBase.CloseType.OPEN;
    }

    /**
     * Compiles the row-template sub-DAG rooted at {@code hop} into a {@link CNodeRow}.
     * <p>
     * Scalar inputs that were compiled to literals are dropped from the input list. The remaining
     * inputs are ordered by {@link HopInputComparator}, so the main input recorded as {@code "X"}
     * (if any) comes first. The first sorted input is then used to derive the row type.
     *
     * @param hop             root of the fused sub-DAG
     * @param memo            memo table providing the best RowTpl entry per HOP
     * @param compileLiterals passed to {@link TemplateUtils#createCNodeData} for leaf inputs
     * @return the ordered input HOPs paired with the constructed template
     */
    @Override
    public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
        HashSet<Hop> inHops = new HashSet<Hop>();
        HashMap<String, Hop> inHops2 = new HashMap<String, Hop>();
        HashMap<Long, CNode> tmp = new HashMap<Long, CNode>();

        hop.resetVisitStatus();
        rConstructCplan(hop, memo, tmp, inHops, inHops2, compileLiterals);
        hop.resetVisitStatus();

        // keep matrices and non-literal scalars; main input first, scalars last
        List<Hop> sinHops = inHops.stream()
            .filter(h -> !h.getDataType().isScalar() || !tmp.get(h.getHopID()).isLiteral())
            .sorted(new HopInputComparator(inHops2.get("X")))
            .collect(Collectors.toList());

        ArrayList<CNode> inputs = new ArrayList<CNode>();
        for (Hop in : sinHops) {
            inputs.add(tmp.get(in.getHopID()));
        }

        CNode output = tmp.get(hop.getHopID());
        CNodeRow tpl = new CNodeRow(inputs, output);
        tpl.setRowType(TemplateUtils.getRowType(hop, sinHops.get(0)));
        tpl.setNumVectorIntermediates(TemplateUtils.countVectorIntermediates(output, new HashSet<Long>()));
        return new Pair<Hop[], CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
    }

    /**
     * Recursively constructs the CNode for {@code hop}, storing it in {@code tmp} keyed by HOP id.
     * <p>
     * Inputs marked as plan references in the best RowTpl memo entry are constructed recursively.
     * All other inputs become data nodes and are added to {@code inHops}. When a row aggregate or
     * dot product consumes a matrix input, that input is recorded in {@code inHops2} under
     * {@code "X"}.
     *
     * @throws RuntimeException if a matrix unary/binary op is not in the supported vector sets, or
     *                          if no CNode could be constructed for {@code hop}
     */
    private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
        // already constructed via another path through the DAG
        if (tmp.containsKey(hop.getHopID())) {
            return;
        }

        // NOTE(review): assumes getBest never returns null for HOPs reached here — TODO confirm
        CPlanMemoTable.MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateBase.TemplateType.RowTpl);
        for (int i = 0; i < hop.getInput().size(); ++i) {
            Hop c = hop.getInput().get(i);
            if (me.isPlanRef(i)) {
                rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
            } else {
                CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
                tmp.put(c.getHopID(), cdata);
                inHops.add(c);
            }
        }

        CNode out = null;
        if (hop instanceof AggUnaryOp) {
            AggUnaryOp agg = (AggUnaryOp) hop;
            Hop in0 = hop.getInput().get(0);
            CNode cdata1 = tmp.get(in0.getHopID());
            if (agg.getDirection() == Hop.Direction.Row && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) {
                if (in0.getDim2() == 1L) {
                    // aggregating a single column per row is just a per-row lookup
                    out = (cdata1.getDataType() == Expression.DataType.SCALAR)
                        ? cdata1
                        : new CNodeUnary(cdata1, CNodeUnary.UnaryType.LOOKUP_R);
                } else {
                    String opcode = "ROW_" + agg.getOp().name().toUpperCase() + "S";
                    out = new CNodeUnary(cdata1, CNodeUnary.UnaryType.valueOf(opcode));
                    inHops2.put("X", in0);
                }
            } else if (agg.getDirection() == Hop.Direction.Col && agg.getOp() == Hop.AggOp.SUM) {
                // rewrite a vector-scalar primitive into its vector-add variant; otherwise pass through
                if (cdata1 instanceof CNodeBinary && ((CNodeBinary) cdata1).getType().isVectorScalarPrimitive()) {
                    out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1),
                        ((CNodeBinary) cdata1).getType().getVectorAddPrimitive());
                } else {
                    out = cdata1;
                }
            }
        } else if (hop instanceof AggBinaryOp) {
            Hop in0 = hop.getInput().get(0);
            Hop in1 = hop.getInput().get(1);
            CNode cdata1 = tmp.get(in0.getHopID());
            CNode cdata2 = tmp.get(in1.getHopID());
            if (HopRewriteUtils.isTransposeOperation(in0)) {
                // t(X) %*% v: consume the transpose's input directly instead of the transpose
                cdata1 = TemplateUtils.skipTranspose(cdata1, in0, tmp, compileLiterals);
                inHops.remove(in0);
                inHops.add(in0.getInput().get(0));
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.VECT_MULT_ADD);
            } else if (in0.getDim2() == 1L && in1.getDim2() == 1L) {
                out = new CNodeBinary(lookup0UnlessScalar(cdata1), lookup0UnlessScalar(cdata2),
                    CNodeBinary.BinType.MULT);
            } else {
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.DOT_PRODUCT);
                inHops2.put("X", in0);
            }
        } else if (hop instanceof UnaryOp) {
            UnaryOp uop = (UnaryOp) hop;
            Hop in0 = hop.getInput().get(0);
            CNode cdata1 = tmp.get(in0.getHopID());
            if (in0.getDim1() > 1L && in0.getDim2() > 1L) {
                if (!HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY)) {
                    throw new RuntimeException("Unsupported unary matrix operation: " + uop.getOp().name());
                }
                String opname = "VECT_" + uop.getOp().name();
                out = new CNodeUnary(cdata1, CNodeUnary.UnaryType.valueOf(opname));
            } else {
                cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, in0);
                String primitiveOpName = uop.getOp().toString();
                out = new CNodeUnary(cdata1, CNodeUnary.UnaryType.valueOf(primitiveOpName));
            }
        } else if (hop instanceof BinaryOp) {
            BinaryOp bop = (BinaryOp) hop;
            Hop in0 = hop.getInput().get(0);
            CNode cdata1 = tmp.get(in0.getHopID());
            CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            if (in0.getDim1() > 1L && in0.getDim2() > 1L) {
                if (!HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY)) {
                    throw new RuntimeException("Unsupported binary matrix operation: " + bop.getOp().name());
                }
                String opname = "VECT_" + bop.getOp().name() + "_SCALAR";
                if (TemplateUtils.isColVector(cdata2)) {
                    cdata2 = new CNodeUnary(cdata2, CNodeUnary.UnaryType.LOOKUP_R);
                }
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.valueOf(opname));
            } else {
                String primitiveOpName = bop.getOp().toString();
                if (TemplateUtils.isColVector(cdata1)) {
                    cdata1 = new CNodeUnary(cdata1, CNodeUnary.UnaryType.LOOKUP_R);
                }
                if (TemplateUtils.isColVector(cdata2)) {
                    cdata2 = new CNodeUnary(cdata2, CNodeUnary.UnaryType.LOOKUP_R);
                }
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.valueOf(primitiveOpName));
            }
        } else if (hop instanceof TernaryOp) {
            TernaryOp top = (TernaryOp) hop;
            CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
            CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
            // only the first and third operands get lookups; the second is used as-is
            cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
            cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
            out = new CNodeTernary(cdata1, cdata2, cdata3, CNodeTernary.TernaryType.valueOf(top.getOp().toString()));
        } else if (hop instanceof ParameterizedBuiltinOp) {
            ParameterizedBuiltinOp pbop = (ParameterizedBuiltinOp) hop;
            CNode cdata1 = tmp.get(pbop.getTargetHop().getHopID());
            // NOTE(review): lookup decision uses input 0 rather than getTargetHop(); presumably the same HOP — verify
            cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
            CNode cdata2 = tmp.get(pbop.getParameterHop("pattern").getHopID());
            CNode cdata3 = tmp.get(pbop.getParameterHop("replacement").getHopID());
            // a literal NaN pattern gets its own ternary type
            CNodeTernary.TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN"))
                ? CNodeTernary.TernaryType.REPLACE_NAN
                : CNodeTernary.TernaryType.REPLACE;
            out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
        } else if (hop instanceof IndexingOp) {
            Hop in0 = hop.getInput().get(0);
            CNode cdata1 = tmp.get(in0.getHopID());
            out = new CNodeTernary(cdata1,
                TemplateUtils.createCNodeData(new LiteralOp(in0.getDim2()), true),
                TemplateUtils.createCNodeData(hop.getInput().get(4), true),
                CNodeTernary.TernaryType.LOOKUP_RC1);
        }

        if (out == null) {
            throw new RuntimeException(hop.getHopID() + " " + hop.getOpString());
        }
        if (out.getDataType().isMatrix()) {
            out.setNumRows(hop.getDim1());
            out.setNumCols(hop.getDim2());
        }
        tmp.put(hop.getHopID(), out);
    }

    /**
     * Returns {@code node} unchanged if it is a scalar; otherwise wraps it in a
     * {@link CNodeUnary.UnaryType#LOOKUP0} unary node.
     */
    private static CNode lookup0UnlessScalar(CNode node) {
        return (node.getDataType() == Expression.DataType.SCALAR)
            ? node
            : new CNodeUnary(node, CNodeUnary.UnaryType.LOOKUP0);
    }

    /**
     * Orders template inputs by descending sort key:
     * <ol>
     *   <li>the main input {@code X} (the HOP passed to the constructor) first;</li>
     *   <li>then non-scalar inputs by decreasing number of cells, where inputs with unknown
     *       dimensions (or a cell count that would overflow a {@code long}) rank just below
     *       {@code X};</li>
     *   <li>scalars last.</li>
     * </ol>
     */
    public static class HopInputComparator
    implements Comparator<Hop> {
        /** Sort key for scalar inputs; the smallest possible key, so scalars sort last. */
        private static final long SCALAR_KEY = Long.MIN_VALUE;
        /** Sort key for the main input; the largest possible key, so it sorts first. */
        private static final long MAIN_INPUT_KEY = Long.MAX_VALUE;
        /** Sort key for inputs of unknown or overflowing size; just below the main input. */
        private static final long UNKNOWN_SIZE_KEY = Long.MAX_VALUE - 1;

        private final Hop _X;

        /**
         * @param X the main input to place first; may be {@code null} if there is none
         */
        public HopInputComparator(Hop X) {
            this._X = X;
        }

        @Override
        public int compare(Hop h1, Hop h2) {
            // arguments swapped for descending order: larger keys sort first
            return Long.compare(sortKey(h2), sortKey(h1));
        }

        /** Computes the sort key of a single input. */
        private long sortKey(Hop h) {
            if (h.getDataType() == Expression.DataType.SCALAR) {
                return SCALAR_KEY;
            }
            if (h == this._X) {
                return MAIN_INPUT_KEY;
            }
            if (!h.dimsKnown()) {
                return UNKNOWN_SIZE_KEY;
            }
            long rows = h.getDim1();
            long cols = h.getDim2();
            // saturate rather than wrap to a negative key, which would sort a huge input after scalars
            if (rows > 0 && cols > 0 && rows > UNKNOWN_SIZE_KEY / cols) {
                return UNKNOWN_SIZE_KEY;
            }
            return rows * cols;
        }
    }
}

