/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.ipa;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ipa.FunctionCallGraph;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;

/**
 * Summarizes, for the functions reachable in a {@link FunctionCallGraph}, which size
 * information observed at the call sites can be propagated into the function bodies.
 *
 * <p>The following is computed at construction time:
 * <ul>
 *   <li><b>valid functions</b>: functions called exactly once, or whose calls all pass inputs
 *       for which {@code dimsKnown()} holds and whose dims and nnz are identical across calls,
 *       with identical literal values wherever the first call passes a literal;</li>
 *   <li><b>safe nnz positions</b>: for each valid function, the input positions whose nnz is
 *       known ({@code >= 0}) at the first call;</li>
 *   <li><b>safe literal positions</b>: for each reachable function, the input positions that
 *       receive an equal literal value at every call.</li>
 * </ul>
 * Dimension-preserving functions are not computed here; they are registered externally via
 * {@link #addDimsPreservingFunction(String)}.
 */
public class FunctionCallSizeInfo {
    /** Call graph the information is derived from. */
    private final FunctionCallGraph _fgraph;
    /** Keys of functions whose call-site sizes are consistent (valid for propagation). */
    private final Set<String> _fcand;
    /** Keys of functions registered as dimension-preserving. */
    private final Set<String> _fcandUnary;
    /** Per valid function: input positions with known nnz at the first call. */
    private final Map<String, Set<Integer>> _fcandSafeNNZ;
    /** Per reachable function: input positions receiving an equal literal at every call. */
    private final Map<String, Set<Integer>> _fSafeLiterals;

    /**
     * Creates and initializes the size information for the given call graph.
     *
     * @param fgraph function call graph to analyze
     * @throws HopsException propagated from the underlying hop inspection
     */
    public FunctionCallSizeInfo(FunctionCallGraph fgraph) throws HopsException {
        this(fgraph, true);
    }

    /**
     * Creates the size information for the given call graph.
     *
     * <p>NOTE(review): {@code init} is currently ignored; the information is always computed
     * eagerly. The parameter is kept for API compatibility.
     *
     * @param fgraph function call graph to analyze
     * @param init   currently unused
     * @throws HopsException propagated from the underlying hop inspection
     */
    public FunctionCallSizeInfo(FunctionCallGraph fgraph, boolean init) throws HopsException {
        this._fgraph = fgraph;
        this._fcand = new HashSet<String>();
        this._fcandUnary = new HashSet<String>();
        this._fcandSafeNNZ = new HashMap<String, Set<Integer>>();
        this._fSafeLiterals = new HashMap<String, Set<Integer>>();
        this.constructFunctionCallSizeInfo();
    }

    /** Returns the number of recorded call sites of the given function. */
    public int getFunctionCallCount(String fkey) {
        return this._fgraph.getFunctionCalls(fkey).size();
    }

    /** Returns true if the function's call-site sizes are consistent (valid for propagation). */
    public boolean isValidFunction(String fkey) {
        return this._fcand.contains(fkey);
    }

    /** Returns the (live, mutable) set of valid function keys. */
    public Set<String> getValidFunctions() {
        return this._fcand;
    }

    /**
     * Returns the functions not valid for propagation. Presumably these are the reachable
     * functions excluding the valid ones; confirm against
     * {@code FunctionCallGraph#getReachableFunctions(Set)}.
     */
    public Set<String> getInvalidFunctions() {
        return this._fgraph.getReachableFunctions(this.getValidFunctions());
    }

    /** Registers the given function as dimension-preserving. */
    public void addDimsPreservingFunction(String fkey) {
        this._fcandUnary.add(fkey);
    }

    /** Returns the (live, mutable) set of registered dimension-preserving functions. */
    public Set<String> getDimsPreservingFunctions() {
        return this._fcandUnary;
    }

    /** Returns true if the function was registered as dimension-preserving. */
    public boolean isDimsPreservingFunction(String fkey) {
        return this._fcandUnary.contains(fkey);
    }

    /** Returns true if the nnz of input {@code pos} of a valid function may be propagated. */
    public boolean isSafeNnz(String fkey, int pos) {
        Set<Integer> positions = this._fcandSafeNNZ.get(fkey);
        return positions != null && positions.contains(pos);
    }

    /** Returns true if at least one input of the function is an equal literal at every call. */
    public boolean hasSafeLiterals(String fkey) {
        Set<Integer> positions = this._fSafeLiterals.get(fkey);
        return positions != null && !positions.isEmpty();
    }

    /** Returns true if input {@code pos} of the function is an equal literal at every call. */
    public boolean isSafeLiteral(String fkey, int pos) {
        Set<Integer> positions = this._fSafeLiterals.get(fkey);
        return positions != null && positions.contains(pos);
    }

    /** Computes valid functions, then their safe nnz positions, then safe literal positions. */
    private void constructFunctionCallSizeInfo() throws HopsException {
        collectValidFunctions();
        // depends on step 1: only valid functions get safe-nnz entries
        collectSafeNnzPositions();
        collectSafeLiteralPositions();
    }

    /** Step 1: marks functions with a single call, or consistent inputs across calls, as valid. */
    private void collectValidFunctions() throws HopsException {
        for (String fkey : this._fgraph.getReachableFunctions()) {
            List<FunctionOp> calls = this._fgraph.getFunctionCalls(fkey);
            if (calls.isEmpty()) {
                continue; // no call site, hence no size information to propagate
            }
            if (calls.size() == 1 || hasConsistentInputs(calls)) {
                this._fcand.add(fkey);
            }
        }
    }

    /** Returns true if every call passes inputs consistent with those of the first call. */
    private static boolean hasConsistentInputs(List<FunctionOp> calls) throws HopsException {
        List<Hop> firstInputs = calls.get(0).getInput();
        for (int i = 1; i < calls.size(); i++) {
            List<Hop> otherInputs = calls.get(i).getInput();
            // inputs cannot be matched positionally if the arities differ
            if (otherInputs.size() != firstInputs.size()) {
                return false;
            }
            for (int j = 0; j < firstInputs.size(); j++) {
                if (!isConsistentInput(firstInputs.get(j), otherInputs.get(j))) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns true if both inputs have known, identical dims and nnz and, when {@code h1} is
     * a literal, {@code h2} is an equal-valued literal.
     */
    private static boolean isConsistentInput(Hop h1, Hop h2) throws HopsException {
        boolean sameSize = h1.dimsKnown() && h2.dimsKnown()
            && h1.getDim1() == h2.getDim1()
            && h1.getDim2() == h2.getDim2()
            && h1.getNnz() == h2.getNnz();
        if (!sameSize) {
            return false;
        }
        return !(h1 instanceof LiteralOp) || isSameLiteral((LiteralOp) h1, h2);
    }

    /** Returns true if {@code other} is a literal with a value equal to {@code literal}. */
    private static boolean isSameLiteral(LiteralOp literal, Hop other) throws HopsException {
        return other instanceof LiteralOp && HopRewriteUtils.isEqualValue(literal, (LiteralOp) other);
    }

    /** Step 2: records, per valid function, the input positions with known nnz. */
    private void collectSafeNnzPositions() throws HopsException {
        for (String fkey : this._fcand) {
            List<Hop> inputs = this._fgraph.getFunctionCalls(fkey).get(0).getInput();
            Set<Integer> safe = new HashSet<Integer>();
            for (int j = 0; j < inputs.size(); j++) {
                // step 1 ensured equal nnz across all calls of a valid function, so a known
                // nnz at the first call is known (and identical) for every call
                if (inputs.get(j).getNnz() >= 0L) {
                    safe.add(j);
                }
            }
            this._fcandSafeNNZ.put(fkey, safe);
        }
    }

    /** Step 3: records, per reachable function, inputs that are an equal literal at every call. */
    private void collectSafeLiteralPositions() throws HopsException {
        for (String fkey : this._fgraph.getReachableFunctions()) {
            List<FunctionOp> calls = this._fgraph.getFunctionCalls(fkey);
            if (calls.isEmpty()) {
                continue; // no call site, hence no literal to propagate
            }
            List<Hop> firstInputs = calls.get(0).getInput();
            // start with all literal positions of the first call ...
            Set<Integer> safe = new HashSet<Integer>();
            for (int j = 0; j < firstInputs.size(); j++) {
                if (firstInputs.get(j) instanceof LiteralOp) {
                    safe.add(j);
                }
            }
            // ... and drop each position where some other call passes a different value
            for (int i = 1; i < calls.size() && !safe.isEmpty(); i++) {
                List<Hop> otherInputs = calls.get(i).getInput();
                for (int j = 0; j < firstInputs.size(); j++) {
                    if (!safe.contains(j)) {
                        continue;
                    }
                    boolean same = j < otherInputs.size()
                        && isSameLiteral((LiteralOp) firstInputs.get(j), otherInputs.get(j));
                    if (!same) {
                        safe.remove(j);
                    }
                }
            }
            this._fSafeLiterals.put(fkey, safe);
        }
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(new int[]{this._fgraph.hashCode(), this._fcand.hashCode(), this._fcandUnary.hashCode(), this._fcandSafeNNZ.hashCode(), this._fSafeLiterals.hashCode()});
    }

    /**
     * Two instances are equal if they refer to the same call graph instance and hold equal
     * valid, dimension-preserving, safe-nnz and safe-literal information.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof FunctionCallSizeInfo)) {
            return false;
        }
        FunctionCallSizeInfo that = (FunctionCallSizeInfo) o;
        return this._fgraph == that._fgraph
            && this._fcand.equals(that._fcand)
            && this._fcandUnary.equals(that._fcandUnary)
            && this._fcandSafeNNZ.equals(that._fcandSafeNNZ)
            && this._fSafeLiterals.equals(that._fSafeLiterals);
    }

    /** Returns a multi-line, human-readable summary intended for debugging/explain output. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Valid functions for propagation: \n");
        for (String fkey : this.getValidFunctions()) {
            appendCallCount(sb, fkey);
            // guard: _fcand is exposed mutably, so a key may lack a safe-nnz entry
            Set<Integer> safeNnz = this._fcandSafeNNZ.get(fkey);
            if (safeNnz != null && !safeNnz.isEmpty()) {
                sb.append("\n----").append(safeNnz);
            }
            sb.append("\n");
        }
        appendFunctionList(sb, "Invalid functions for propagation: \n", this.getInvalidFunctions());
        appendFunctionList(sb, "Dimensions-preserving functions: \n", this.getDimsPreservingFunctions());
        appendInputNames(sb, "Valid scalars for propagation: \n", this._fSafeLiterals);
        appendInputNames(sb, "Valid #non-zeros for propagation: \n", this._fcandSafeNNZ);
        return sb.toString();
    }

    /** Appends {@code "--<fkey>: <callCount>"} (without newline) and returns {@code sb}. */
    private StringBuilder appendCallCount(StringBuilder sb, String fkey) {
        return sb.append("--").append(fkey).append(": ").append(this.getFunctionCallCount(fkey));
    }

    /** Appends a header plus one call-count line per function; appends nothing if empty. */
    private void appendFunctionList(StringBuilder sb, String header, Set<String> fkeys) {
        if (fkeys.isEmpty()) {
            return;
        }
        sb.append(header);
        for (String fkey : fkeys) {
            appendCallCount(sb, fkey).append("\n");
        }
    }

    /** Appends a header plus, per function, the listed input positions with their first-call names. */
    private void appendInputNames(StringBuilder sb, String header, Map<String, Set<Integer>> positions) {
        sb.append(header);
        for (Map.Entry<String, Set<Integer>> entry : positions.entrySet()) {
            String fkey = entry.getKey();
            Set<Integer> posSet = entry.getValue();
            sb.append("--").append(fkey).append(": ");
            if (!posSet.isEmpty()) {
                List<Hop> inputs = this._fgraph.getFunctionCalls(fkey).get(0).getInput();
                for (Integer pos : posSet) {
                    sb.append(pos).append(":").append(inputs.get(pos).getName()).append(" ");
                }
            }
            sb.append("\n");
        }
    }
}

