/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.api;

import java.io.IOException;
import java.util.List;
import java.util.Objects;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.execution.QueryExecution;
import org.apache.spark.sql.types.StructType;
import org.apache.sysml.api.DMLException;
import org.apache.sysml.api.MLBlock;
import org.apache.sysml.api.MLContext;
import org.apache.sysml.api.MLOutput;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.spark.functions.GetMIMBFromRow;
import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import scala.Tuple2;

/**
 * A Spark {@code Dataset<Row>} that additionally carries the
 * {@link MatrixCharacteristics} of the matrix it represents and the
 * {@link MLContext} used to evaluate operations on it.
 *
 * <p>Every operation follows the same pattern: the context is reset, this
 * matrix (and, for matrix-matrix operators, the second operand) is registered
 * as a script input, a small generated script is executed, and the script's
 * {@code output} variable is wrapped in a new {@code MLMatrix}.
 *
 * <p>Only {@link #MLMatrix(Dataset, MatrixCharacteristics, MLContext)} assigns
 * {@link #mc}; instances built through the other constructors keep
 * {@code mc == null}, and every method that reads {@code mc} directly
 * ({@link #numRows()}, {@link #numCols()}, {@link #rowsPerBlock()},
 * {@link #colsPerBlock()} and the matrix-matrix operators) throws
 * {@link NullPointerException} on them.
 *
 * <p>The {@code $...} methods carry the JVM-encoded names of Scala symbolic
 * operators (for example {@code $plus} for {@code +}); each has the same
 * effect as its plainly named counterpart where one exists.
 */
@Deprecated
public class MLMatrix extends Dataset<Row> {
    private static final long serialVersionUID = -7005940673916671165L;

    /** Dimensions and block sizes of this matrix; {@code null} unless supplied at construction. */
    protected MatrixCharacteristics mc;

    /** Context through which every generated script is executed. */
    protected MLContext ml;

    /**
     * Statement appended to every generated script that produces a result.
     * Kept non-final and package-visible because code outside this file may
     * reference it.
     */
    static String writeStmt = "write(output, \"tmp\", format=\"binary\", rows_in_block=1000, cols_in_block=1000);";

    /**
     * Creates a matrix over a logical plan. {@link #mc} is left {@code null}.
     *
     * @param sparkSession session passed to the {@code Dataset} constructor
     * @param logicalPlan  plan passed to the {@code Dataset} constructor
     * @param ml           context used to evaluate operations
     */
    protected MLMatrix(SparkSession sparkSession, LogicalPlan logicalPlan, MLContext ml) {
        super(sparkSession, logicalPlan, RowEncoder.apply(null));
        this.ml = ml;
    }

    /**
     * Creates a matrix over a logical plan. {@link #mc} is left {@code null}.
     *
     * @param sqlContext  SQL context passed to the {@code Dataset} constructor
     * @param logicalPlan plan passed to the {@code Dataset} constructor
     * @param ml          context used to evaluate operations
     */
    protected MLMatrix(SQLContext sqlContext, LogicalPlan logicalPlan, MLContext ml) {
        super(sqlContext, logicalPlan, RowEncoder.apply(null));
        this.ml = ml;
    }

    /**
     * Creates a matrix over a query execution. {@link #mc} is left {@code null}.
     *
     * @param sparkSession   session passed to the {@code Dataset} constructor
     * @param queryExecution query execution passed to the {@code Dataset} constructor
     * @param ml             context used to evaluate operations
     */
    protected MLMatrix(SparkSession sparkSession, QueryExecution queryExecution, MLContext ml) {
        super(sparkSession, queryExecution, RowEncoder.apply(null));
        this.ml = ml;
    }

    /**
     * Creates a matrix over a query execution, using the session of the given
     * SQL context. {@link #mc} is left {@code null}.
     *
     * @param sqlContext     SQL context whose {@code sparkSession()} is used
     * @param queryExecution query execution passed to the {@code Dataset} constructor
     * @param ml             context used to evaluate operations
     */
    protected MLMatrix(SQLContext sqlContext, QueryExecution queryExecution, MLContext ml) {
        super(sqlContext.sparkSession(), queryExecution, RowEncoder.apply(null));
        this.ml = ml;
    }

    /**
     * Creates a matrix that reuses the session and logical plan of an
     * existing data frame.
     *
     * @param df data frame whose session and logical plan are reused
     * @param mc characteristics recorded for this matrix
     * @param ml context used to evaluate operations
     * @throws DMLRuntimeException declared for callers; not thrown by this body
     */
    protected MLMatrix(Dataset<Row> df, MatrixCharacteristics mc, MLContext ml) throws DMLRuntimeException {
        super(df.sparkSession(), df.logicalPlan(), RowEncoder.apply(null));
        this.mc = mc;
        this.ml = ml;
    }

    /**
     * Builds an {@code MLMatrix} from binary blocks: each block is mapped
     * through {@link GetMLBlock} and the resulting rows are turned into a data
     * frame with the schema from {@link MLBlock#getDefaultSchemaForBinaryBlock()}.
     *
     * @param ml           context attached to the new matrix
     * @param sparkSession session used to create the data frame
     * @param blocks       indexed matrix blocks
     * @param mc           characteristics attached to the new matrix
     * @return the new matrix
     * @throws DMLRuntimeException propagated from the {@code MLMatrix} constructor
     */
    static MLMatrix createMLMatrix(MLContext ml, SparkSession sparkSession, JavaPairRDD<MatrixIndexes, MatrixBlock> blocks, MatrixCharacteristics mc) throws DMLRuntimeException {
        StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
        // map() already yields a JavaRDD<Row>, which createDataFrame accepts directly.
        Dataset<Row> frame = sparkSession.createDataFrame(blocks.map(new GetMLBlock()), schema);
        return new MLMatrix(frame, mc, ml);
    }

    /**
     * Same as {@link #createMLMatrix(MLContext, SparkSession, JavaPairRDD, MatrixCharacteristics)},
     * using the session of the given SQL context.
     *
     * @param ml         context attached to the new matrix
     * @param sqlContext SQL context whose {@code sparkSession()} is used
     * @param blocks     indexed matrix blocks
     * @param mc         characteristics attached to the new matrix
     * @return the new matrix
     * @throws DMLRuntimeException propagated from the delegate
     */
    static MLMatrix createMLMatrix(MLContext ml, SQLContext sqlContext, JavaPairRDD<MatrixIndexes, MatrixBlock> blocks, MatrixCharacteristics mc) throws DMLRuntimeException {
        return MLMatrix.createMLMatrix(ml, sqlContext.sparkSession(), blocks, mc);
    }

    /**
     * Executes a script that writes this matrix to {@code filePath} in the
     * given {@code format}.
     *
     * <p>NOTE(review): both arguments are concatenated into the script text
     * without escaping, so a value containing {@code "} changes the script
     * that is executed. Callers must pass trusted values.
     *
     * @param filePath target path placed inside the script's {@code write} call
     * @param format   format name placed inside the script's {@code write} call
     * @throws IOException  propagated from script execution
     * @throws DMLException propagated from script execution
     */
    public void write(String filePath, String format) throws IOException, DMLException {
        this.ml.reset();
        this.ml.registerInput("left", this);
        this.ml.executeScript("left = read(\"\"); output=left; write(output, \"" + filePath + "\", format=\"" + format + "\");");
    }

    /**
     * Evaluates {@code nrow} or {@code ncol} on this matrix through a script
     * and returns the single cell of the resulting 1x1 matrix.
     *
     * @param fn builtin name; only {@code "nrow"} and {@code "ncol"} are accepted
     * @return value at (0, 0) of the script output
     * @throws DMLRuntimeException if {@code fn} is not supported or the output
     *                             does not consist of exactly one block
     */
    private double getScalarBuiltinFunctionResult(String fn) throws IOException, DMLException {
        if (!"nrow".equals(fn) && !"ncol".equals(fn)) {
            throw new DMLRuntimeException("The function " + fn + " is not yet supported in MLMatrix");
        }
        this.ml.reset();
        this.ml.registerInput("left", MLMatrix.getRDDLazily(this), this.mc.getRows(), this.mc.getCols(), this.mc.getRowsPerBlock(), this.mc.getColsPerBlock(), this.mc.getNonZeros());
        this.ml.registerOutput("output");
        String script = "left = read(\"\");val = " + fn + "(left); output = matrix(val, rows=1, cols=1); " + writeStmt;
        MLOutput out = this.ml.executeScript(script);
        List<Tuple2<MatrixIndexes, MatrixBlock>> blocks = out.getBinaryBlockedRDD("output").collect();
        if (blocks == null || blocks.size() != 1) {
            throw new DMLRuntimeException("Error while computing the function: " + fn);
        }
        return blocks.get(0)._2().getValue(0, 0);
    }

    /**
     * Returns the row count: the recorded value when {@code mc.rowsKnown()},
     * otherwise the result of evaluating {@code nrow} through a script.
     *
     * @return number of rows
     * @throws IOException  propagated from script execution
     * @throws DMLException propagated from script execution
     */
    public long numRows() throws IOException, DMLException {
        return this.mc.rowsKnown()
            ? this.mc.getRows()
            : (long) this.getScalarBuiltinFunctionResult("nrow");
    }

    /**
     * Returns the column count: the recorded value when {@code mc.colsKnown()},
     * otherwise the result of evaluating {@code ncol} through a script.
     *
     * @return number of columns
     * @throws IOException  propagated from script execution
     * @throws DMLException propagated from script execution
     */
    public long numCols() throws IOException, DMLException {
        return this.mc.colsKnown()
            ? this.mc.getCols()
            : (long) this.getScalarBuiltinFunctionResult("ncol");
    }

    /** @return {@code mc.getRowsPerBlock()} */
    public int rowsPerBlock() {
        return this.mc.getRowsPerBlock();
    }

    /** @return {@code mc.getColsPerBlock()} */
    public int colsPerBlock() {
        return this.mc.getColsPerBlock();
    }

    /** Builds the script for {@code output = left <op> right}. */
    private String getScript(String binaryOperator) {
        return "left = read(\"\");right = read(\"\");output = left " + binaryOperator + " right; " + writeStmt;
    }

    /**
     * Builds the script for a matrix-scalar operation.
     *
     * @param binaryOperator operator token placed between the operands
     * @param scalar         scalar operand, rendered with default {@code double} formatting
     * @param isScalarLeft   {@code true} to emit {@code scalar <op> left},
     *                       {@code false} to emit {@code left <op> scalar}
     */
    private String getScalarBinaryScript(String binaryOperator, double scalar, boolean isScalarLeft) {
        if (isScalarLeft) {
            return "left = read(\"\");output = " + scalar + " " + binaryOperator + " left ;" + writeStmt;
        }
        return "left = read(\"\");output = left " + binaryOperator + " " + scalar + ";" + writeStmt;
    }

    /**
     * Converts the rows of {@code mat} into indexed matrix blocks by mapping
     * each row through {@link GetMIMBFromRow}.
     *
     * @param mat matrix whose rows are converted
     * @return pair RDD of matrix indexes and blocks
     */
    static JavaPairRDD<MatrixIndexes, MatrixBlock> getRDDLazily(MLMatrix mat) {
        return mat.rdd().toJavaRDD().mapToPair(new GetMIMBFromRow());
    }

    /** Wraps the {@code output} variable of an executed script in a new matrix bound to this context. */
    private MLMatrix wrapOutput(MLOutput out) throws DMLException {
        return MLMatrix.createMLMatrix(this.ml, this.sparkSession(), out.getBinaryBlockedRDD("output"), out.getMatrixCharacteristics("output"));
    }

    /** Resets the context, registers this matrix as {@code left}, runs {@code script} and wraps its output. */
    private MLMatrix executeOnSelf(String script) throws IOException, DMLException {
        this.ml.reset();
        this.ml.registerInput("left", this);
        this.ml.registerOutput("output");
        return this.wrapOutput(this.ml.executeScript(script));
    }

    /**
     * Evaluates {@code this <op> that}.
     *
     * <p>Both operands must have identical block sizes. For {@code %*%} the
     * column count of this matrix must equal the row count of {@code that};
     * for every other operator both row and column counts must match.
     *
     * @param that right operand
     * @param op   operator token
     * @return the result matrix
     * @throws NullPointerException if {@code that} is {@code null}
     * @throws DMLRuntimeException  if block sizes or dimensions are incompatible
     */
    private MLMatrix matrixBinaryOp(MLMatrix that, String op) throws IOException, DMLException {
        Objects.requireNonNull(that, "that must not be null");
        if (this.mc.getRowsPerBlock() != that.mc.getRowsPerBlock() || this.mc.getColsPerBlock() != that.mc.getColsPerBlock()) {
            throw new DMLRuntimeException("Incompatible block sizes: brlen:" + this.mc.getRowsPerBlock() + "!=" + that.mc.getRowsPerBlock() + " || bclen:" + this.mc.getColsPerBlock() + "!=" + that.mc.getColsPerBlock());
        }
        if (op.equals("%*%")) {
            if (this.mc.getCols() != that.mc.getRows()) {
                throw new DMLRuntimeException("Dimensions mismatch:" + this.mc.getCols() + "!=" + that.mc.getRows());
            }
        } else if (this.mc.getRows() != that.mc.getRows() || this.mc.getCols() != that.mc.getCols()) {
            throw new DMLRuntimeException("Dimensions mismatch:" + this.mc.getRows() + "!=" + that.mc.getRows() + " || " + this.mc.getCols() + "!=" + that.mc.getCols());
        }
        this.ml.reset();
        this.ml.registerInput("left", this);
        this.ml.registerInput("right", that);
        this.ml.registerOutput("output");
        return this.wrapOutput(this.ml.executeScript(this.getScript(op)));
    }

    /**
     * Evaluates a matrix-scalar operation.
     *
     * @param scalar       scalar operand
     * @param op           operator token
     * @param isScalarLeft whether the scalar is the left operand
     * @return the result matrix
     * @throws NullPointerException if {@code scalar} is {@code null}
     */
    private MLMatrix scalarBinaryOp(Double scalar, String op, boolean isScalarLeft) throws IOException, DMLException {
        // Validate before touching the context so a bad argument leaves it untouched.
        Objects.requireNonNull(scalar, "scalar must not be null");
        return this.executeOnSelf(this.getScalarBinaryScript(op, scalar, isScalarLeft));
    }

    /** Evaluates {@code this > that}. */
    public MLMatrix $greater(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, ">");
    }

    /** Evaluates {@code this < that}. */
    public MLMatrix $less(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "<");
    }

    /** Evaluates {@code this >= that}. */
    public MLMatrix $greater$eq(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, ">=");
    }

    /** Evaluates {@code this <= that}. */
    public MLMatrix $less$eq(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "<=");
    }

    /** Evaluates {@code this == that}. */
    public MLMatrix $eq$eq(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "==");
    }

    /** Evaluates {@code this != that}. */
    public MLMatrix $bang$eq(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "!=");
    }

    /** Evaluates {@code this ^ that}. */
    public MLMatrix $up(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "^");
    }

    /** Evaluates {@code this ^ that}; same as {@link #$up(MLMatrix)}. */
    public MLMatrix exp(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "^");
    }

    /** Evaluates {@code this + that}. */
    public MLMatrix $plus(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "+");
    }

    /** Evaluates {@code this + that}; same as {@link #$plus(MLMatrix)}. */
    public MLMatrix add(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "+");
    }

    /** Evaluates {@code this - that}. */
    public MLMatrix $minus(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "-");
    }

    /** Evaluates {@code this - that}; same as {@link #$minus(MLMatrix)}. */
    public MLMatrix minus(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "-");
    }

    /** Evaluates {@code this * that}. */
    public MLMatrix $times(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "*");
    }

    /** Evaluates {@code this * that}; same as {@link #$times(MLMatrix)}. */
    public MLMatrix elementWiseMultiply(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "*");
    }

    /** Evaluates {@code this / that}. */
    public MLMatrix $div(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "/");
    }

    /** Evaluates {@code this / that}; same as {@link #$div(MLMatrix)}. */
    public MLMatrix divide(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "/");
    }

    /** Evaluates {@code this %/% that}. */
    public MLMatrix $percent$div$percent(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "%/%");
    }

    /** Evaluates {@code this %/% that}; same as {@link #$percent$div$percent(MLMatrix)}. */
    public MLMatrix integerDivision(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "%/%");
    }

    /** Evaluates {@code this %% that}. */
    public MLMatrix $percent$percent(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "%%");
    }

    /** Evaluates {@code this %% that}; same as {@link #$percent$percent(MLMatrix)}. */
    public MLMatrix modulus(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "%%");
    }

    /** Evaluates {@code this %*%  that}; requires {@code this} columns to equal {@code that} rows. */
    public MLMatrix $percent$times$percent(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "%*%");
    }

    /** Evaluates {@code this %*% that}; same as {@link #$percent$times$percent(MLMatrix)}. */
    public MLMatrix multiply(MLMatrix that) throws IOException, DMLException {
        return this.matrixBinaryOp(that, "%*%");
    }

    /**
     * Evaluates {@code t(left)} with this matrix registered as {@code left}.
     *
     * @return the result matrix
     * @throws IOException  propagated from script execution
     * @throws DMLException propagated from script execution
     */
    public MLMatrix transpose() throws IOException, DMLException {
        return this.executeOnSelf("left = read(\"\");output = t(left); " + writeStmt);
    }

    /** Evaluates {@code this + scalar}. */
    public MLMatrix $plus(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "+", false);
    }

    /** Evaluates {@code this + scalar}; same as {@link #$plus(Double)}. */
    public MLMatrix add(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "+", false);
    }

    /** Evaluates {@code this - scalar}. */
    public MLMatrix $minus(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "-", false);
    }

    /** Evaluates {@code this - scalar}; same as {@link #$minus(Double)}. */
    public MLMatrix minus(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "-", false);
    }

    /** Evaluates {@code this * scalar}. */
    public MLMatrix $times(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "*", false);
    }

    /** Evaluates {@code this * scalar}; same as {@link #$times(Double)}. */
    public MLMatrix elementWiseMultiply(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "*", false);
    }

    /** Evaluates {@code this / scalar}. */
    public MLMatrix $div(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "/", false);
    }

    /** Evaluates {@code this / scalar}; same as {@link #$div(Double)}. */
    public MLMatrix divide(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "/", false);
    }

    /** Evaluates {@code this > scalar}. */
    public MLMatrix $greater(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, ">", false);
    }

    /** Evaluates {@code this < scalar}. */
    public MLMatrix $less(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "<", false);
    }

    /** Evaluates {@code this >= scalar}. */
    public MLMatrix $greater$eq(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, ">=", false);
    }

    /** Evaluates {@code this <= scalar}. */
    public MLMatrix $less$eq(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "<=", false);
    }

    /** Evaluates {@code this == scalar}. */
    public MLMatrix $eq$eq(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "==", false);
    }

    /** Evaluates {@code this != scalar}. */
    public MLMatrix $bang$eq(Double scalar) throws IOException, DMLException {
        return this.scalarBinaryOp(scalar, "!=", false);
    }
}

