/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops;

import java.util.ArrayList;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.MemoTable;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.lops.ConvolutionTransform;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;

public class ConvolutionOp
extends Hop
implements Hop.MultiThreadedHop {
    private Hop.ConvOp op;
    private int _maxNumThreads = -1;

    private ConvolutionOp() {
    }

    public ConvolutionOp(String l, Expression.DataType dt, Expression.ValueType vt, Hop.ConvOp o, ArrayList<Hop> inp) {
        super(l, dt, vt);
        this.op = o;
        for (int i = 0; i < inp.size(); ++i) {
            Hop in = inp.get(i);
            this.getInput().add(i, in);
            in.getParent().add(this);
        }
        this.refreshSizeInformation();
    }

    public Hop.ConvOp getOp() {
        return this.op;
    }

    @Override
    public String getOpString() {
        return "" + HopsConv2Lops.get((Object)this.op);
    }

    private boolean isEligibleForSpark() {
        return false;
    }

    @Override
    public Lop constructLops() throws HopsException, LopsException {
        if (this.getLops() != null) {
            return this.getLops();
        }
        LopProperties.ExecType et = this.optFindExecType();
        ArrayList<Hop> inputs = this.getInput();
        switch (this.op) {
            case MAX_POOLING: 
            case MAX_POOLING_BACKWARD: 
            case DIRECT_CONV2D: 
            case DIRECT_CONV2D_BACKWARD_DATA: 
            case DIRECT_CONV2D_BACKWARD_FILTER: 
            case BIAS_ADD: 
            case BIAS_MULTIPLY: {
                if (et == LopProperties.ExecType.CP || et == LopProperties.ExecType.GPU) {
                    this.setLops(this.constructConvolutionLops(et, inputs));
                    break;
                }
                throw new HopsException("Unimplemented ConvolutionOp for execution type: " + et.name());
            }
            default: {
                throw new HopsException("Unsupported lops construction for operation type '" + (Object)((Object)this.op) + "'.");
            }
        }
        this.constructAndSetLopsDataFlowProperties();
        return this.getLops();
    }

    public void setOp(Hop.ConvOp op) {
        this.op = op;
    }

    private int getNumExpectedInputs() {
        switch (this.op) {
            case MAX_POOLING_BACKWARD: 
            case DIRECT_CONV2D: 
            case DIRECT_CONV2D_BACKWARD_DATA: 
            case DIRECT_CONV2D_BACKWARD_FILTER: {
                return 14;
            }
            case BIAS_ADD: 
            case BIAS_MULTIPLY: {
                return 2;
            }
        }
        return 13;
    }

    private boolean isInputReLU(Hop input) {
        return input instanceof UnaryOp && ((UnaryOp)input).getOp() == Hop.OpOp1.SELP;
    }

    private boolean isInputConv2d(Hop input) {
        return input instanceof ConvolutionOp && ((ConvolutionOp)input).getOp() == Hop.ConvOp.DIRECT_CONV2D;
    }

    public Lop constructConvolutionLops(LopProperties.ExecType et, ArrayList<Hop> inputs) throws HopsException, LopsException {
        if (inputs.size() != this.getNumExpectedInputs()) {
            throw new HopsException("Incorrect number of inputs for " + this.op.name());
        }
        Lop in = null;
        Lop in2 = null;
        ArrayList<Hop> inputs1 = inputs;
        int k = OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads);
        ConvolutionTransform.OperationTypes lopOp = (ConvolutionTransform.OperationTypes)((Object)HopsConv2Lops.get((Object)this.op));
        if (this.op == Hop.ConvOp.MAX_POOLING && this.isInputReLU(inputs.get(0))) {
            in = inputs.get(0).getInput().get(0).constructLops();
            lopOp = ConvolutionTransform.OperationTypes.RELU_MAX_POOLING;
        } else if (this.op == Hop.ConvOp.BIAS_ADD && this.isInputConv2d(inputs.get(0))) {
            lopOp = ConvolutionTransform.OperationTypes.DIRECT_CONV2D_BIAS_ADD;
            in = inputs.get(0).getInput().get(0).constructLops();
            in2 = inputs.get(1).constructLops();
            inputs1 = inputs.get(0).getInput();
        } else {
            in = inputs.get(0).constructLops();
        }
        ConvolutionTransform transform1 = new ConvolutionTransform(in, lopOp, this.getDataType(), this.getValueType(), et, k);
        this.setOutputDimensions(transform1);
        this.setLineNumbers(transform1);
        in.addOutput(transform1);
        if (in2 != null) {
            transform1.addInput(in2);
            in2.addOutput(transform1);
        }
        for (int i = 1; i < inputs1.size(); ++i) {
            Lop ltmp = inputs1.get(i).constructLops();
            transform1.addInput(ltmp);
            ltmp.addOutput(transform1);
        }
        transform1.setLevel();
        return transform1;
    }

    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        double sparsity = 1.0;
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        return 0.0;
    }

    @Override
    protected long[] inferOutputCharacteristics(MemoTable memo) {
        ConvolutionParameters params;
        long[] ret = new long[3];
        if (this.op == Hop.ConvOp.BIAS_ADD || this.op == Hop.ConvOp.BIAS_MULTIPLY) {
            MatrixCharacteristics[] mc = memo.getAllInputStats(this.getInput());
            ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1L;
            ret[1] = mc[0].colsKnown() ? mc[0].getCols() : -1L;
            ret[2] = -1L;
            return ret[0] > 0L && ret[1] > 0L ? ret : null;
        }
        try {
            params = this.parseInput();
        }
        catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
        switch (this.op) {
            case MAX_POOLING: {
                ret[0] = this.getInput().get((int)0)._dim1;
                ret[1] = ConvolutionOp.getExtractedVal(params.C, params.P, params.Q);
                ret[2] = -1L;
                break;
            }
            case DIRECT_CONV2D: {
                ret[0] = this.getInput().get((int)0)._dim1;
                ret[1] = ConvolutionOp.getExtractedVal(this.getInput().get((int)1)._dim1, params.P, params.Q);
                ret[2] = -1L;
                break;
            }
            case DIRECT_CONV2D_BACKWARD_FILTER: {
                ret[0] = this.getInput().get((int)1)._dim1;
                ret[1] = this.getInput().get((int)1)._dim2;
                ret[2] = -1L;
                break;
            }
            case MAX_POOLING_BACKWARD: 
            case DIRECT_CONV2D_BACKWARD_DATA: {
                ret[0] = this.getInput().get((int)0)._dim1;
                ret[1] = this.getInput().get((int)0)._dim2;
                ret[2] = -1L;
                break;
            }
            default: {
                throw new RuntimeException("Unsupported op:" + this.op.name());
            }
        }
        if (LOG.isDebugEnabled() && (ret[0] <= 0L || ret[1] <= 0L)) {
            LOG.debug((Object)("Unknown dimensions for ConvolutionOp in inferOutputCharacteristics:" + this.op.name() + " " + ret[0] + " " + ret[1] + " img_dim=[" + params.N + " " + params.C + " " + params.H + " " + params.W + "] filter_dim=[" + params.K + " " + params.C + " " + params.H + " " + params.W + "] output_feature_map=[" + params.P + " " + params.Q + "] stride=[" + params.stride_h + " " + params.stride_w + "] pad=[" + params.pad_h + " " + params.pad_w + "]"));
        }
        return ret[0] > 0L && ret[1] > 0L ? ret : null;
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    @Override
    protected LopProperties.ExecType optFindExecType() throws HopsException {
        LopProperties.ExecType REMOTE;
        this.checkAndSetForcedPlatform();
        LopProperties.ExecType execType = REMOTE = OptimizerUtils.isSparkExecutionMode() ? LopProperties.ExecType.SPARK : LopProperties.ExecType.MR;
        if (this._etypeForced != null) {
            this._etype = this.findGPUExecTypeByMemEstimate(this._etypeForced);
        } else {
            this._etype = OptimizerUtils.isMemoryBasedOptLevel() ? this.findGPUExecTypeByMemEstimate(this.findExecTypeByMemEstimate()) : REMOTE;
            this.checkAndSetInvalidCPDimsAndSize();
        }
        LopProperties.ExecType execType2 = this._etype = !this.isEligibleForSpark() && this._etype == REMOTE ? LopProperties.ExecType.CP : this._etype;
        if (ConfigurationManager.isDynamicRecompilation() && !this.dimsKnown(true) && this._etype == REMOTE) {
            this.setRequiresRecompile();
        }
        return this._etype;
    }

    ConvolutionParameters parseInput() throws DMLRuntimeException {
        ConvolutionParameters params = null;
        params = this.op == Hop.ConvOp.MAX_POOLING_BACKWARD || this.op == Hop.ConvOp.DIRECT_CONV2D || this.op == Hop.ConvOp.DIRECT_CONV2D_BACKWARD_FILTER || this.op == Hop.ConvOp.DIRECT_CONV2D_BACKWARD_DATA ? new ConvolutionParameters(this.computeSizeInformation(this.getInput().get(6)), this.computeSizeInformation(this.getInput().get(7)), this.computeSizeInformation(this.getInput().get(8)), this.computeSizeInformation(this.getInput().get(9)), this.computeSizeInformation(this.getInput().get(10)), this.computeSizeInformation(this.getInput().get(12)), this.computeSizeInformation(this.getInput().get(13)), this.computeSizeInformation(this.getInput().get(2)), this.computeSizeInformation(this.getInput().get(3)), this.computeSizeInformation(this.getInput().get(4)), this.computeSizeInformation(this.getInput().get(5)), this._maxNumThreads) : new ConvolutionParameters(this.computeSizeInformation(this.getInput().get(5)), this.computeSizeInformation(this.getInput().get(6)), this.computeSizeInformation(this.getInput().get(7)), this.computeSizeInformation(this.getInput().get(8)), this.computeSizeInformation(this.getInput().get(9)), this.computeSizeInformation(this.getInput().get(11)), this.computeSizeInformation(this.getInput().get(12)), this.computeSizeInformation(this.getInput().get(1)), this.computeSizeInformation(this.getInput().get(2)), this.computeSizeInformation(this.getInput().get(3)), this.computeSizeInformation(this.getInput().get(4)), this._maxNumThreads);
        return params;
    }

    public static long getExtractedVal(long val1, long val2, long val3) {
        if (val1 == -1L || val2 == -1L || val3 == -1L) {
            return -1L;
        }
        return val1 * val2 * val3;
    }

    @Override
    public void refreshSizeInformation() {
        ConvolutionParameters params;
        if (this.op == Hop.ConvOp.BIAS_ADD || this.op == Hop.ConvOp.BIAS_MULTIPLY) {
            Hop input1 = this.getInput().get(0);
            this.setDim1(input1.getDim1());
            this.setDim2(input1.getDim2());
            return;
        }
        try {
            params = this.parseInput();
        }
        catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
        switch (this.op) {
            case MAX_POOLING: {
                this._dim1 = this.getInput().get((int)0)._dim1;
                this._dim2 = ConvolutionOp.getExtractedVal(params.C, params.P, params.Q);
                this._nnz = -1L;
                break;
            }
            case MAX_POOLING_BACKWARD: {
                this._dim1 = this.getInput().get((int)0)._dim1;
                this._dim2 = this.getInput().get((int)0)._dim2;
                this._nnz = -1L;
                break;
            }
            case DIRECT_CONV2D: {
                this._dim1 = this.getInput().get((int)0)._dim1;
                this._dim2 = ConvolutionOp.getExtractedVal(this.getInput().get((int)1)._dim1, params.P, params.Q);
                this._nnz = -1L;
                break;
            }
            case DIRECT_CONV2D_BACKWARD_DATA: {
                this._dim1 = this.getInput().get((int)0)._dim1;
                this._dim2 = this.getInput().get((int)0)._dim2;
                this._nnz = -1L;
                break;
            }
            case DIRECT_CONV2D_BACKWARD_FILTER: {
                this._dim1 = this.getInput().get((int)1)._dim1;
                this._dim2 = this.getInput().get((int)1)._dim2;
                this._nnz = -1L;
                break;
            }
            default: {
                throw new RuntimeException("The sizes are not refreshed for " + this.op.name());
            }
        }
        if (LOG.isDebugEnabled() && (this._dim1 <= 0L || this._dim2 <= 0L)) {
            LOG.debug((Object)("Unknown dimensions for ConvolutionOp in refreshSizeInformation:" + this.op.name() + " " + this._dim1 + " " + this._dim2 + " img_dim=[" + params.N + " " + params.C + " " + params.H + " " + params.W + "] filter_dim=[" + params.K + " " + params.C + " " + params.H + " " + params.W + "] output_feature_map=[" + params.P + " " + params.Q + "] stride=[" + params.stride_h + " " + params.stride_w + "] pad=[" + params.pad_h + " " + params.pad_w + "]"));
        }
    }

    @Override
    public Object clone() throws CloneNotSupportedException {
        ConvolutionOp ret = new ConvolutionOp();
        ret.clone(this, false);
        ret.op = this.op;
        ret._maxNumThreads = this._maxNumThreads;
        return ret;
    }

    @Override
    public boolean compare(Hop that) {
        boolean ret;
        if (!(that instanceof ConvolutionOp)) {
            return false;
        }
        ConvolutionOp that2 = (ConvolutionOp)that;
        boolean bl = ret = this.op == that2.op && this.getInput().size() == that.getInput().size() && this._maxNumThreads == that2._maxNumThreads;
        if (ret) {
            for (int i = 0; i < this._input.size(); ++i) {
                ret &= this.getInput().get(i) == that2.getInput().get(i);
            }
        }
        return ret;
    }

    @Override
    public void setMaxNumThreads(int k) {
        this._maxNumThreads = k;
    }

    @Override
    public int getMaxNumThreads() {
        return this._maxNumThreads;
    }
}

