/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.cp;

import java.util.ArrayList;
import java.util.Arrays;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.runtime.util.ConvolutionUtils;

/**
 * Instruction handling the convolution-style opcodes dispatched in this class:
 * {@code maxpooling}, {@code relu_maxpooling}, {@code maxpooling_backward},
 * {@code conv2d}, {@code conv2d_bias_add}, {@code conv2d_backward_filter},
 * {@code conv2d_backward_data}, {@code bias_add}, {@code bias_multiply} and
 * {@code relu_backward}.
 *
 * <p>Opcodes are matched case-insensitively everywhere in this class. The
 * numerical work itself is delegated to {@link LibMatrixDNN}; this class only
 * resolves operands, sizes the output block and hands matrices in and out of
 * the {@link ExecutionContext}.
 */
public class ConvolutionCPInstruction
extends UnaryCPInstruction {
    /** Second matrix operand; {@code null} for the pooling forward opcodes. */
    private final CPOperand _in2;
    /** Third matrix operand; only set for {@code conv2d_bias_add}. */
    private final CPOperand _in3;
    /** Scalar operands read as N, C, H, W (indices 0..3); {@code null} for the bias/relu_backward opcodes. */
    private final ArrayList<CPOperand> _input_shape;
    /** Scalar operands of the filter shape; indices 0, 2 and 3 are read as K, R, S. {@code null} for the bias/relu_backward opcodes. */
    private final ArrayList<CPOperand> _filter_shape;
    /** Scalar operands read as stride_h, stride_w (indices 0, 1). */
    private final ArrayList<CPOperand> _stride;
    /** Scalar operands read as pad_h, pad_w (indices 0, 1). */
    private final ArrayList<CPOperand> _padding;
    /** Thread count forwarded unchanged to {@link LibMatrixDNN} / {@link ConvolutionParameters}. */
    private final int _numThreads;

    /**
     * Creates an instruction for {@code bias_add}, {@code bias_multiply} or
     * {@code relu_backward}. Stride and padding are left as empty lists and the
     * shape lists are left {@code null}.
     *
     * @param in         first matrix operand
     * @param in2        second matrix operand (bias, or dout for relu_backward)
     * @param out        output operand
     * @param opcode     instruction opcode, matched case-insensitively
     * @param istr       full instruction string
     * @param numThreads thread count forwarded to the DNN library
     * @throws DMLRuntimeException if {@code opcode} is not one of the three supported opcodes
     */
    public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, int numThreads) throws DMLRuntimeException {
        this(in, in2, null, out, opcode, istr, new ArrayList<CPOperand>(), new ArrayList<CPOperand>(), null, null, numThreads);
        // Case-insensitive so that it agrees with parseInstruction and processInstruction.
        if (!isBiasOrReluBackwardOpcode(opcode)) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + opcode);
        }
    }

    /**
     * Creates an instruction with a single matrix input (used by
     * {@link #parseInstruction(String)} for the pooling forward opcodes).
     *
     * @param in           matrix operand
     * @param out          output operand
     * @param opcode       instruction opcode
     * @param istr         full instruction string
     * @param stride       stride operands (h, w)
     * @param padding      padding operands (h, w)
     * @param input_shape  input shape operands (N, C, H, W)
     * @param filter_shape filter shape operands
     * @param numThreads   thread count forwarded to the DNN library
     */
    public ConvolutionCPInstruction(CPOperand in, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads) {
        this(in, null, null, out, opcode, istr, stride, padding, input_shape, filter_shape, numThreads);
    }

    /**
     * Creates an instruction with two matrix inputs (used by
     * {@link #parseInstruction(String)} for {@code maxpooling_backward},
     * {@code conv2d}, {@code conv2d_backward_filter} and
     * {@code conv2d_backward_data}).
     *
     * @param in           first matrix operand
     * @param in2          second matrix operand
     * @param out          output operand
     * @param opcode       instruction opcode
     * @param istr         full instruction string
     * @param stride       stride operands (h, w)
     * @param padding      padding operands (h, w)
     * @param input_shape  input shape operands (N, C, H, W)
     * @param filter_shape filter shape operands
     * @param numThreads   thread count forwarded to the DNN library
     */
    public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads) {
        this(in, in2, null, out, opcode, istr, stride, padding, input_shape, filter_shape, numThreads);
    }

    /**
     * Creates an instruction with three matrix inputs (used by
     * {@link #parseInstruction(String)} for {@code conv2d_bias_add}). All other
     * constructors delegate here, so this is the single place where the fields
     * are assigned.
     *
     * @param in           first matrix operand
     * @param in2          second matrix operand (read as the bias by conv2d_bias_add)
     * @param in3          third matrix operand (read as the filter by conv2d_bias_add)
     * @param out          output operand
     * @param opcode       instruction opcode
     * @param istr         full instruction string
     * @param stride       stride operands (h, w)
     * @param padding      padding operands (h, w)
     * @param input_shape  input shape operands (N, C, H, W)
     * @param filter_shape filter shape operands
     * @param numThreads   thread count forwarded to the DNN library
     */
    public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads) {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr);
        this._in2 = in2;
        this._in3 = in3;
        this._cptype = CPInstruction.CPINSTRUCTION_TYPE.Convolution;
        this._stride = stride;
        this._padding = padding;
        this._input_shape = input_shape;
        this._filter_shape = filter_shape;
        this._numThreads = numThreads;
    }

    /**
     * Parses an instruction string into a {@code ConvolutionCPInstruction}.
     *
     * <p>Field layout after the opcode: the matrix inputs (1, 2 or 3 depending
     * on the opcode), then 2 stride, 2 padding, 4 input-shape and 4 filter-shape
     * operands, then the output operand and finally the thread count. The
     * bias/relu_backward opcodes carry only two inputs, the output and the
     * thread count.
     *
     * @param str instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not handled by this class
     *                             (or if the field-count check rejects the instruction)
     */
    public static ConvolutionCPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling")) {
            InstructionUtils.checkNumFields(parts, 15);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand out = new CPOperand(parts[14]);
            ArrayList<CPOperand> stride = parseOperands(parts, 2, 2);
            ArrayList<CPOperand> padding = parseOperands(parts, 4, 2);
            ArrayList<CPOperand> input_shape = parseOperands(parts, 6, 4);
            ArrayList<CPOperand> filter_shape = parseOperands(parts, 10, 4);
            int k = Integer.parseInt(parts[15]);
            return new ConvolutionCPInstruction(in, out, opcode, str, stride, padding, input_shape, filter_shape, k);
        }
        if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("conv2d") || opcode.equalsIgnoreCase("conv2d_backward_filter") || opcode.equalsIgnoreCase("conv2d_backward_data")) {
            InstructionUtils.checkNumFields(parts, 16);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[15]);
            ArrayList<CPOperand> stride = parseOperands(parts, 3, 2);
            ArrayList<CPOperand> padding = parseOperands(parts, 5, 2);
            ArrayList<CPOperand> input_shape = parseOperands(parts, 7, 4);
            ArrayList<CPOperand> filter_shape = parseOperands(parts, 11, 4);
            int k = Integer.parseInt(parts[16]);
            return new ConvolutionCPInstruction(in, in2, out, opcode, str, stride, padding, input_shape, filter_shape, k);
        }
        if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
            InstructionUtils.checkNumFields(parts, 17);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand out = new CPOperand(parts[16]);
            ArrayList<CPOperand> stride = parseOperands(parts, 4, 2);
            ArrayList<CPOperand> padding = parseOperands(parts, 6, 2);
            ArrayList<CPOperand> input_shape = parseOperands(parts, 8, 4);
            ArrayList<CPOperand> filter_shape = parseOperands(parts, 12, 4);
            int k = Integer.parseInt(parts[17]);
            return new ConvolutionCPInstruction(in, in2, in3, out, opcode, str, stride, padding, input_shape, filter_shape, k);
        }
        if (isBiasOrReluBackwardOpcode(opcode)) {
            InstructionUtils.checkNumFields(parts, 4);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[3]);
            int k = Integer.parseInt(parts[4]);
            return new ConvolutionCPInstruction(in, in2, out, opcode, str, k);
        }
        throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionCPInstruction: " + str);
    }

    /**
     * Tells whether {@code opcode} is one of the two-input opcodes that carry no
     * stride/padding/shape operands. The comparison ignores case.
     */
    private static boolean isBiasOrReluBackwardOpcode(String opcode) {
        return opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply");
    }

    /**
     * Builds a list of {@code count} operands from consecutive instruction
     * fields, starting at {@code parts[from]}.
     */
    private static ArrayList<CPOperand> parseOperands(String[] parts, int from, int count) {
        ArrayList<CPOperand> operands = new ArrayList<CPOperand>(count);
        for (int i = from; i < from + count; i++) {
            operands.add(new CPOperand(parts[i]));
        }
        return operands;
    }

    /**
     * Resolves the scalar operand at {@code index} of {@code aL} to an int.
     *
     * @throws DMLRuntimeException if the scalar cannot be resolved, or if its
     *                             long value does not fit into an int (the value
     *                             is rejected rather than silently truncated)
     */
    private int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) throws DMLRuntimeException {
        CPOperand operand = aL.get(index);
        long value = ec.getScalarInput(operand.getName(), operand.getValueType(), operand.isLiteral()).getLongValue();
        if (value < Integer.MIN_VALUE || value > Integer.MAX_VALUE) {
            throw new DMLRuntimeException("Expected the scalar operand " + operand.getName() + " of " + this.instOpcode + " to fit into an int, but found " + value);
        }
        return (int)value;
    }

    /**
     * Executes {@code relu_backward}: the output has the dimensions of the
     * first input and is only allocated and computed when both inputs are
     * non-empty; otherwise the freshly constructed block is returned as is.
     *
     * @param ec execution context providing the inputs and receiving the output
     * @throws DMLRuntimeException if acquiring, computing or storing a matrix fails
     */
    public void processReluBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
        MatrixBlock input = ec.getMatrixInput(this.input1.getName());
        MatrixBlock dout = ec.getMatrixInput(this._in2.getName());
        MatrixBlock outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), false);
        if (!input.isEmptyBlock() && !dout.isEmptyBlock()) {
            outputBlock.allocateDenseOrSparseBlock();
            LibMatrixDNN.reluBackward(input, dout, outputBlock, this._numThreads);
        }
        ec.releaseMatrixInput(this.input1.getName());
        ec.releaseMatrixInput(this._in2.getName());
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock);
    }

    /**
     * Executes {@code bias_add}. The bias must have exactly one column.
     * <ul>
     *   <li>input and bias both empty: an empty sparse block of the input's size;</li>
     *   <li>only the bias empty: a copy of the input;</li>
     *   <li>otherwise: a dense block filled by {@code LibMatrixDNN.biasAdd}.</li>
     * </ul>
     *
     * @param ec execution context providing the inputs and receiving the output
     * @throws DMLRuntimeException if the bias has more or fewer than one column,
     *                             or if a matrix operation fails
     */
    public void processBiasAddInstruction(ExecutionContext ec) throws DMLRuntimeException {
        MatrixBlock input = ec.getMatrixInput(this.input1.getName());
        MatrixBlock bias = ec.getMatrixInput(this._in2.getName());
        this.checkBiasIsSingleColumn(ec, bias);
        MatrixBlock outputBlock;
        if (input.isEmptyBlock() && bias.isEmptyBlock()) {
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), true);
        } else if (bias.isEmptyBlock()) {
            outputBlock = new MatrixBlock(input);
        } else {
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), false);
            outputBlock.allocateDenseBlock();
            LibMatrixDNN.biasAdd(input, bias, outputBlock, this._numThreads);
        }
        ec.releaseMatrixInput(this.input1.getName());
        ec.releaseMatrixInput(this._in2.getName());
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock);
    }

    /**
     * Executes {@code bias_multiply}. The bias must have exactly one column.
     * An empty bias yields an empty sparse block of the input's size; otherwise
     * a dense block is filled by {@code LibMatrixDNN.biasMultiply}.
     *
     * @param ec execution context providing the inputs and receiving the output
     * @throws DMLRuntimeException if the bias has more or fewer than one column,
     *                             or if a matrix operation fails
     */
    public void processBiasMultiplyInstruction(ExecutionContext ec) throws DMLRuntimeException {
        MatrixBlock input = ec.getMatrixInput(this.input1.getName());
        MatrixBlock bias = ec.getMatrixInput(this._in2.getName());
        this.checkBiasIsSingleColumn(ec, bias);
        MatrixBlock outputBlock;
        if (bias.isEmptyBlock()) {
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), true);
        } else {
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), false);
            outputBlock.allocateDenseBlock();
            LibMatrixDNN.biasMultiply(input, bias, outputBlock, this._numThreads);
        }
        ec.releaseMatrixInput(this.input1.getName());
        ec.releaseMatrixInput(this._in2.getName());
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock);
    }

    /**
     * Rejects a bias matrix that does not have exactly one column. Both matrix
     * inputs of the bias instructions have already been acquired when this is
     * called, so they are released before the exception is raised to keep the
     * acquire/release calls balanced on the failure path as well.
     */
    private void checkBiasIsSingleColumn(ExecutionContext ec, MatrixBlock bias) throws DMLRuntimeException {
        int numBiasCols = bias.getNumColumns();
        if (numBiasCols != 1) {
            ec.releaseMatrixInput(this.input1.getName());
            ec.releaseMatrixInput(this._in2.getName());
            throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + numBiasCols);
        }
    }

    /**
     * Executes this instruction. The bias/relu_backward opcodes are forwarded
     * to their dedicated methods; every other opcode reads padding, stride and
     * shape scalars, derives P and Q through {@link ConvolutionUtils}, and then
     * either produces an empty sparse output (when a relevant input is empty)
     * or a dense output filled by the matching {@link LibMatrixDNN} routine.
     *
     * <p>Output dimensions per opcode:
     * <ul>
     *   <li>{@code maxpooling}, {@code relu_maxpooling}: N x (C*P*Q)</li>
     *   <li>{@code maxpooling_backward}, {@code conv2d_backward_data}: N x (C*H*W)</li>
     *   <li>{@code conv2d}, {@code conv2d_bias_add}: N x (K*P*Q)</li>
     *   <li>{@code conv2d_backward_filter}: K x (C*R*S)</li>
     * </ul>
     *
     * @param ec execution context providing the inputs and receiving the output
     * @throws DMLRuntimeException if the opcode is unsupported or an operation fails
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        if (this.instOpcode.equalsIgnoreCase("bias_add")) {
            this.processBiasAddInstruction(ec);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("bias_multiply")) {
            this.processBiasMultiplyInstruction(ec);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("relu_backward")) {
            this.processReluBackwardInstruction(ec);
            return;
        }
        MatrixBlock outputBlock = null;
        MatrixBlock matBlock = ec.getMatrixInput(this.input1.getName());
        int pad_h = this.getScalarInput(ec, this._padding, 0);
        int pad_w = this.getScalarInput(ec, this._padding, 1);
        int stride_h = this.getScalarInput(ec, this._stride, 0);
        int stride_w = this.getScalarInput(ec, this._stride, 1);
        int N = this.getScalarInput(ec, this._input_shape, 0);
        int C = this.getScalarInput(ec, this._input_shape, 1);
        int H = this.getScalarInput(ec, this._input_shape, 2);
        int W = this.getScalarInput(ec, this._input_shape, 3);
        // Index 1 of the filter shape is not read here; only K, R and S are needed.
        int K = this.getScalarInput(ec, this._filter_shape, 0);
        int R = this.getScalarInput(ec, this._filter_shape, 2);
        int S = this.getScalarInput(ec, this._filter_shape, 3);
        int P = (int)ConvolutionUtils.getP(H, R, stride_h, pad_h);
        int Q = (int)ConvolutionUtils.getQ(W, S, stride_w, pad_w);
        ConvolutionParameters params = new ConvolutionParameters(N, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, this._numThreads);
        if (this.instOpcode.equalsIgnoreCase("maxpooling") || this.instOpcode.equalsIgnoreCase("relu_maxpooling")) {
            if (matBlock.isEmptyBlock()) {
                outputBlock = new MatrixBlock(N, C * P * Q, true);
            } else {
                outputBlock = this.getDenseOutputBlock(N, C * P * Q);
                if (this.instOpcode.equalsIgnoreCase("maxpooling")) {
                    // Plain maxpooling starts from the most negative finite double;
                    // relu_maxpooling keeps the freshly allocated block untouched.
                    Arrays.fill(outputBlock.getDenseBlock(), -Double.MAX_VALUE);
                }
                LibMatrixDNN.maxpooling(matBlock, outputBlock, params);
            }
        } else if (this.instOpcode.equalsIgnoreCase("maxpooling_backward")) {
            MatrixBlock dout = ec.getMatrixInput(this._in2.getName());
            if (matBlock.isEmptyBlock() || dout.isEmptyBlock()) {
                outputBlock = new MatrixBlock(N, C * H * W, true);
            } else {
                outputBlock = this.getDenseOutputBlock(N, C * H * W);
                LibMatrixDNN.maxpoolingBackward(matBlock, dout, outputBlock, params);
            }
            ec.releaseMatrixInput(this._in2.getName());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d")) {
            MatrixBlock filter = ec.getMatrixInput(this._in2.getName());
            if (filter.isEmptyBlock() || matBlock.isEmptyBlock()) {
                outputBlock = new MatrixBlock(N, K * P * Q, true);
            } else {
                outputBlock = this.getDenseOutputBlock(N, K * P * Q);
                LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
            }
            ec.releaseMatrixInput(this._in2.getName());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
            // For this opcode the third operand is the filter and the second is the bias.
            MatrixBlock filter = ec.getMatrixInput(this._in3.getName());
            MatrixBlock bias = ec.getMatrixInput(this._in2.getName());
            if ((filter.isEmptyBlock() || matBlock.isEmptyBlock()) && bias.isEmptyBlock()) {
                outputBlock = new MatrixBlock(N, K * P * Q, true);
            } else {
                outputBlock = this.getDenseOutputBlock(N, K * P * Q);
                if (!bias.isEmptyBlock()) {
                    params.bias = bias;
                }
                LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
            }
            ec.releaseMatrixInput(this._in3.getName());
            ec.releaseMatrixInput(this._in2.getName());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
            MatrixBlock dout = ec.getMatrixInput(this._in2.getName());
            if (dout.isEmptyBlock() || matBlock.isEmptyBlock()) {
                outputBlock = new MatrixBlock(K, C * R * S, true);
            } else {
                outputBlock = this.getDenseOutputBlock(K, C * R * S);
                LibMatrixDNN.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
            }
            ec.releaseMatrixInput(this._in2.getName());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
            MatrixBlock dout = ec.getMatrixInput(this._in2.getName());
            if (dout.isEmptyBlock() || matBlock.isEmptyBlock()) {
                outputBlock = new MatrixBlock(N, C * H * W, true);
            } else {
                outputBlock = this.getDenseOutputBlock(N, C * H * W);
                LibMatrixDNN.conv2dBackwardData(matBlock, dout, outputBlock, params);
            }
            ec.releaseMatrixInput(this._in2.getName());
        } else {
            throw new DMLRuntimeException("Unsupported op code " + this.instOpcode);
        }
        ec.releaseMatrixInput(this.input1.getName());
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock);
    }

    /**
     * Creates a non-sparse {@code numRows x numCols} block and allocates its
     * dense storage.
     */
    private MatrixBlock getDenseOutputBlock(int numRows, int numCols) throws DMLRuntimeException {
        MatrixBlock outputBlock = new MatrixBlock(numRows, numCols, false);
        outputBlock.allocateDenseBlock();
        return outputBlock;
    }
}

