/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.ipa;

import java.util.ArrayList;
import java.util.List;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ipa.FunctionCallGraph;
import org.apache.sysml.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysml.hops.ipa.IPAPass;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;

/**
 * IPA pass that propagates literal function arguments into function bodies.
 * <p>
 * For every reachable function for which
 * {@link FunctionCallSizeInfo#hasSafeLiterals(String)} reports safe literal
 * arguments, the literal values passed at the first recorded call site are
 * bound to the corresponding input parameter names. Those bindings are then
 * substituted into the hop DAGs of the function body via
 * {@link Recompiler#rReplaceLiterals}. A binding is dropped as soon as a
 * statement block updates the bound variable.
 * <p>
 * NOTE(review): taking the values from only the first call site assumes that a
 * "safe literal" is one passed identically at every call site. Confirm this
 * against FunctionCallSizeInfo.
 */
public class IPAPassPropagateReplaceLiterals
extends IPAPass {

    /**
     * Reports whether this pass should run.
     * <p>
     * The pass places no preconditions on the call graph, so it always returns
     * {@code true}.
     *
     * @param fgraph the function call graph (unused)
     * @return always {@code true}
     */
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return true;
    }

    /**
     * Replaces references to safe literal parameters inside function bodies with
     * the literal values.
     *
     * @param prog       the DML program whose function statement blocks are rewritten
     * @param fgraph     call graph providing the reachable functions and their call sites
     * @param fcallSizes per-function information about which inputs are safe literals
     * @throws HopsException if literal replacement fails in any hop DAG
     */
    @Override
    public void rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) throws HopsException {
        for (String fkey : fgraph.getReachableFunctions()) {
            // Filter first, so the call-site lookup is only done for candidate functions.
            if (!fcallSizes.hasSafeLiterals(fkey)) {
                continue;
            }
            List<FunctionOp> calls = fgraph.getFunctionCalls(fkey);
            if (calls == null || calls.isEmpty()) {
                // No call site to take the literal values from.
                continue;
            }
            FunctionOp first = calls.get(0);

            FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            ArrayList<DataIdentifier> finputs = fstmt.getInputParams();

            // Bind each safe literal argument (positional) to its parameter name.
            LocalVariableMap callVars = new LocalVariableMap();
            for (int j = 0; j < finputs.size(); ++j) {
                if (!fcallSizes.isSafeLiteral(fkey, j)) {
                    continue;
                }
                LiteralOp lit = (LiteralOp) first.getInput().get(j);
                callVars.put(finputs.get(j).getName(), ScalarObjectFactory.createScalarObject(lit.getValueType(), lit));
            }

            // callVars is shared across the body's blocks in program order. A binding
            // removed by one block stays removed for every block after it.
            for (StatementBlock sb : fstmt.getBody()) {
                rReplaceLiterals(sb, callVars);
            }
        }
    }

    /**
     * Recursively substitutes known constants into the hops of a statement block
     * and of any nested statement blocks.
     * <p>
     * Every variable reported by {@code sb.variablesUpdated()} is removed from
     * {@code constants} before any substitution in this block. This is
     * conservative: a block such as {@code x = x + 1} does not get its read of
     * {@code x} replaced. For compound blocks, variablesUpdated() presumably
     * includes updates made inside the nested bodies, which would make
     * substitution into loop predicates safe. TODO confirm.
     *
     * @param sb        the statement block to process
     * @param constants mutable map of variable name to constant value; entries
     *                  are removed in place as variables are updated
     * @throws HopsException if literal replacement fails
     */
    private void rReplaceLiterals(StatementBlock sb, LocalVariableMap constants) throws HopsException {
        // Drop bindings for variables this block (re)assigns.
        for (String varname : sb.variablesUpdated().getVariableNames()) {
            if (constants.keySet().contains(varname)) {
                constants.remove(varname);
            }
        }

        if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement ws = (WhileStatement) sb.getStatement(0);
            replaceLiterals(wsb.getPredicateHops(), constants);
            for (StatementBlock current : ws.getBody()) {
                rReplaceLiterals(current, constants);
            }
        } else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement ifs = (IfStatement) sb.getStatement(0);
            replaceLiterals(isb.getPredicateHops(), constants);
            for (StatementBlock current : ifs.getIfBody()) {
                rReplaceLiterals(current, constants);
            }
            for (StatementBlock current : ifs.getElseBody()) {
                rReplaceLiterals(current, constants);
            }
        } else if (sb instanceof ForStatementBlock) {
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fs = (ForStatement) sb.getStatement(0);
            // Loop bounds and increment are evaluated through their own hop DAGs.
            replaceLiterals(fsb.getFromHops(), constants);
            replaceLiterals(fsb.getToHops(), constants);
            replaceLiterals(fsb.getIncrementHops(), constants);
            for (StatementBlock current : fs.getBody()) {
                rReplaceLiterals(current, constants);
            }
        } else {
            // Plain (generic) statement block: rewrite its hop DAG roots directly.
            replaceLiterals(sb.getHops(), constants);
        }
    }

    /**
     * Substitutes constants into every DAG rooted at {@code roots}.
     * <p>
     * The visit status is reset before and after the traversal, so that shared
     * sub-DAGs are handled consistently and later passes start clean. A
     * {@code null} list is a no-op.
     *
     * @param roots     hop DAG roots, may be {@code null}
     * @param constants variable name to constant value bindings
     * @throws HopsException wrapping any failure raised during replacement
     */
    private static void replaceLiterals(ArrayList<Hop> roots, LocalVariableMap constants) throws HopsException {
        if (roots == null) {
            return;
        }
        try {
            Hop.resetVisitStatus(roots);
            for (Hop root : roots) {
                Recompiler.rReplaceLiterals(root, constants, true);
            }
            Hop.resetVisitStatus(roots);
        }
        catch (Exception ex) {
            throw new HopsException(ex);
        }
    }

    /**
     * Substitutes constants into the single DAG rooted at {@code root}, such as
     * a loop predicate or a for-loop bound.
     * <p>
     * A {@code null} root is a no-op.
     *
     * @param root      hop DAG root, may be {@code null}
     * @param constants variable name to constant value bindings
     * @throws HopsException wrapping any failure raised during replacement
     */
    private static void replaceLiterals(Hop root, LocalVariableMap constants) throws HopsException {
        if (root == null) {
            return;
        }
        try {
            root.resetVisitStatus();
            Recompiler.rReplaceLiterals(root, constants, true);
            root.resetVisitStatus();
        }
        catch (Exception ex) {
            throw new HopsException(ex);
        }
    }
}

