/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.scripts.nn.layers;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.scripts.nn.layers.fm.Backward_output;
import org.apache.sysml.scripts.nn.layers.fm.Init_output;

/**
 * MLContext wrapper around the factorization-machine layer script
 * {@code scripts/nn/layers/fm.dml}.
 *
 * <p>The constructor loads the DML text of that file from the classpath and installs it as this
 * script's text. Each of {@link #init}, {@link #forward} and {@link #backward} builds and executes
 * a fresh {@link Script} that sources the same file and calls one DML function. The
 * {@code *__docs} / {@code *__source} methods return the DML header comment / full source of the
 * corresponding function as static strings.
 */
public class Fm
extends Script {

    /** Classpath location (relative to the root) of the wrapped DML script. */
    private static final String SCRIPT_PATH = "scripts/nn/layers/fm.dml";

    /** Size of the character buffer used while reading the DML resource. */
    private static final int READ_BUFFER_SIZE = 1024;

    /**
     * Creates the wrapper and loads the full text of {@code scripts/nn/layers/fm.dml} as this
     * script's source.
     *
     * @throws IllegalStateException if the DML resource cannot be found on the classpath
     * @throws UncheckedIOException if reading the DML resource fails
     */
    public Fm() {
        this.setScriptString(readScriptResource(SCRIPT_PATH));
    }

    /**
     * Reads a classpath resource completely as UTF-8 text.
     *
     * <p>The reader is closed via try-with-resources, so the underlying stream does not leak.
     * Read failures are rethrown instead of being swallowed: a truncated script would only fail
     * later, with a far less helpful error.
     *
     * @param path resource path relative to the classpath root (no leading slash)
     * @return the full text of the resource
     * @throws IllegalStateException if the resource does not exist
     * @throws UncheckedIOException if an I/O error occurs while reading
     */
    private static String readScriptResource(String path) {
        String resource = "/" + path;
        InputStream inputStream = Script.class.getResourceAsStream(resource);
        if (inputStream == null) {
            throw new IllegalStateException("DML script not found on classpath: " + resource);
        }
        StringBuilder text = new StringBuilder();
        char[] buffer = new char[READ_BUFFER_SIZE];
        try (Reader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8)) {
            int count;
            while ((count = reader.read(buffer)) != -1) {
                text.append(buffer, 0, count);
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read DML script: " + resource, e);
        }
        return text.toString();
    }

    /**
     * Executes the DML {@code init} function of {@code fm.dml} and returns its outputs.
     *
     * <p>Per the DML source: {@code w0} is a 1x1 zero matrix, {@code W} is a (d, 1) zero matrix
     * and {@code V} is a (d, k) uniform random matrix.
     *
     * @param object value bound to DML {@code n} (declared {@code int} in the DML signature)
     * @param object2 value bound to DML {@code d}, the number of features
     * @param object3 value bound to DML {@code k}, the factorization dimensionality
     * @return the {@code w0}, {@code W} and {@code V} matrices, in that order
     */
    public Init_output init(Object object, Object object2, Object object3) {
        Script script = new Script("source('scripts/nn/layers/fm.dml') as mlcontextns;[w0, W, V] = mlcontextns::init(n, d, k);");
        script.in("n", object).in("d", object2).in("k", object3).out("w0").out("W").out("V");
        MLResults results = script.execute();
        return new Init_output(results.getMatrix("w0"), results.getMatrix("W"), results.getMatrix("V"));
    }

    /**
     * Returns the signature and header comment of the DML {@code init} function.
     *
     * @return the DML documentation text for {@code init}
     */
    public String init__docs() {
        return "init = function(int n, int d, int k)\n    return (matrix[double] w0, matrix[double] W, matrix[double] V) {\n  /*\n   * This function initializes the parameters.\n   *\n   * Inputs:\n   *  - d: the number of features, is an integer.\n   *  - k: the factorization dimensionality, is an integer.\n   *\n   * Outputs:\n   *  - w0: the global bias, of shape (1,).\n   *  - W : the strength of each feature, of shape (d, 1).\n   *  - V : factorized interaction terms, of shape (d, k).\n   */\n";
    }

    /**
     * Returns the full DML source of the {@code init} function.
     *
     * @return the DML source text for {@code init}
     */
    public String init__source() {
        return "init = function(int n, int d, int k)\n    return (matrix[double] w0, matrix[double] W, matrix[double] V) {\n  /*\n   * This function initializes the parameters.\n   *\n   * Inputs:\n   *  - d: the number of features, is an integer.\n   *  - k: the factorization dimensionality, is an integer.\n   *\n   * Outputs:\n   *  - w0: the global bias, of shape (1,).\n   *  - W : the strength of each feature, of shape (d, 1).\n   *  - V : factorized interaction terms, of shape (d, k).\n   */\n  w0 = matrix(0, rows=1, cols=1)\n  W  = matrix(0, rows=d, cols=1)\n  V  = rand(rows=d, cols=k, min=0.0, max=1.0, pdf=\"uniform\", sparsity=.08)\n}\n";
    }

    /**
     * Executes the DML {@code forward} function of {@code fm.dml}.
     *
     * <p>Per the DML source, it computes
     * {@code (X %*% W) + 0.5 * rowSums((X %*% V)^2 - (X^2 %*% V^2)) + w0}.
     *
     * @param object value bound to DML {@code X}, shape (n, d)
     * @param object2 value bound to DML {@code w0}, the global bias
     * @param object3 value bound to DML {@code W}, shape (d, 1)
     * @param object4 value bound to DML {@code V}, shape (d, k)
     * @return the DML output {@code out}, shape (n, 1)
     */
    public Matrix forward(Object object, Object object2, Object object3, Object object4) {
        Script script = new Script("source('scripts/nn/layers/fm.dml') as mlcontextns;out = mlcontextns::forward(X, w0, W, V);");
        script.in("X", object).in("w0", object2).in("W", object3).in("V", object4).out("out");
        return script.execute().getMatrix("out");
    }

    /**
     * Returns the signature and header comment of the DML {@code forward} function.
     *
     * @return the DML documentation text for {@code forward}
     */
    public String forward__docs() {
        return "forward = function(matrix[double] X, matrix[double] w0, matrix[double] W, matrix[double] V)\n    return (matrix[double] out) {\n  /*\n   * Computes the model.\n   *\n   * Reference:\n   *  - Factorization Machines, Steffen Rendle.\n   *\n   * Inputs:\n   *  - X : n examples with d features, of shape (n, d).\n   *  - w0: the global bias, of shape (1,).\n   *  - W : the strength of each feature, of shape (d, 1).\n   *  - V : factorized interaction terms, of shape (d, k).\n   *\n   * Outputs:\n   *  - out : target vector, of shape (n, 1).\n   */\n";
    }

    /**
     * Returns the full DML source of the {@code forward} function.
     *
     * @return the DML source text for {@code forward}
     */
    public String forward__source() {
        return "forward = function(matrix[double] X, matrix[double] w0, matrix[double] W, matrix[double] V)\n    return (matrix[double] out) {\n  /*\n   * Computes the model.\n   *\n   * Reference:\n   *  - Factorization Machines, Steffen Rendle.\n   *\n   * Inputs:\n   *  - X : n examples with d features, of shape (n, d).\n   *  - w0: the global bias, of shape (1,).\n   *  - W : the strength of each feature, of shape (d, 1).\n   *  - V : factorized interaction terms, of shape (d, k).\n   *\n   * Outputs:\n   *  - out : target vector, of shape (n, 1).\n   */\n  out = (X %*% W) + (0.5 * rowSums((X %*% V)^2 - (X^2 %*% V^2)) ) + w0  # shape (n, 1)\n}\n";
    }

    /**
     * Executes the DML {@code backward} function of {@code fm.dml} and returns the parameter
     * gradients.
     *
     * <p>NOTE(review): in the DML source (see {@link #backward__source()}), the loop
     * {@code for (j in 2:k)} appends {@code xv[,k]} on every iteration, where {@code xv[,j]}
     * appears to be intended. Any fix belongs in {@code fm.dml} itself, since that file is what
     * gets executed. The DML header also lists a {@code dX} output that the function does not
     * return.
     *
     * @param object value bound to DML {@code dout}, upstream gradient of shape (n, 1)
     * @param object2 value bound to DML {@code X}, shape (n, d)
     * @param object3 value bound to DML {@code w0}
     * @param object4 value bound to DML {@code W}, shape (d, 1)
     * @param object5 value bound to DML {@code V}, shape (d, k)
     * @return the {@code dw0}, {@code dW} and {@code dV} gradient matrices, in that order
     */
    public Backward_output backward(Object object, Object object2, Object object3, Object object4, Object object5) {
        Script script = new Script("source('scripts/nn/layers/fm.dml') as mlcontextns;[dw0, dW, dV] = mlcontextns::backward(dout, X, w0, W, V);");
        script.in("dout", object).in("X", object2).in("w0", object3).in("W", object4).in("V", object5).out("dw0").out("dW").out("dV");
        MLResults results = script.execute();
        return new Backward_output(results.getMatrix("dw0"), results.getMatrix("dW"), results.getMatrix("dV"));
    }

    /**
     * Returns the signature and header comment of the DML {@code backward} function.
     *
     * @return the DML documentation text for {@code backward}
     */
    public String backward__docs() {
        return "backward = function(matrix[double] dout, matrix[double] X, matrix[double] w0, matrix[double] W,\n                    matrix[double] V)\n    return (matrix[double] dw0, matrix[double] dW, matrix[double] dV) {\n  /*\n   * This function accepts the upstream gradients w.r.t. output target\n   * vector, and returns the gradients of the loss w.r.t. the\n   * parameters.\n   *\n   * Inputs:\n   *  - dout : the gradient of the loss function w.r.t y, of\n   *     shape (n, 1).\n   *  - X, w0, W, V are as mentioned in the above forward function.\n   *\n   * Outputs:\n   *  - dX : the gradient of loss function w.r.t  X, of shape (n, d).\n   *  - dw0: the gradient of loss function w.r.t w0, of shape (1,).\n   *  - dW : the gradient of loss function w.r.t  W, of shape (d, 1).\n   *  - dV : the gradient of loss function w.r.t  V, of shape (d, k).\n   */\n";
    }

    /**
     * Returns the full DML source of the {@code backward} function, exactly as it appears in
     * {@code fm.dml}. This includes the suspicious {@code xv[,k]} loop noted on
     * {@link #backward}.
     *
     * @return the DML source text for {@code backward}
     */
    public String backward__source() {
        return "backward = function(matrix[double] dout, matrix[double] X, matrix[double] w0, matrix[double] W,\n                    matrix[double] V)\n    return (matrix[double] dw0, matrix[double] dW, matrix[double] dV) {\n  /*\n   * This function accepts the upstream gradients w.r.t. output target\n   * vector, and returns the gradients of the loss w.r.t. the\n   * parameters.\n   *\n   * Inputs:\n   *  - dout : the gradient of the loss function w.r.t y, of\n   *     shape (n, 1).\n   *  - X, w0, W, V are as mentioned in the above forward function.\n   *\n   * Outputs:\n   *  - dX : the gradient of loss function w.r.t  X, of shape (n, d).\n   *  - dw0: the gradient of loss function w.r.t w0, of shape (1,).\n   *  - dW : the gradient of loss function w.r.t  W, of shape (d, 1).\n   *  - dV : the gradient of loss function w.r.t  V, of shape (d, k).\n   */\n  n = nrow(X)\n  d = ncol(X)\n  k = ncol(V)\n\n  # 1. gradient of target vector w.r.t. w0\n  g_w0 = as.matrix(1)  # shape (1, 1)\n\n  ## gradient of loss function w.r.t. w0\n  dw0  = colSums(dout)  # shape (1, 1)\n\n  # 2. gradient target vector w.r.t. W\n  g_W = X  # shape (n, d)\n\n  ## gradient of loss function w.r.t. W\n  dW  =  t(g_W) %*% dout  # shape (d, 1)\n\n  # TODO: VECTORIZE THE FOLLOWING CODE (https://issues.apache.org/jira/browse/SYSTEMML-2102)\n  # 3. gradient of target vector w.r.t. V\n  # First term -> g_V1 = t(X) %*% (X %*% V)  # shape (d, k)\n\n  ## gradient of loss function w.r.t. V\n  # First term -> t(X) %*% X %*% V\n\n\n  # Second term -> V(i,f) * (X(i))^2\n  Xt = t( X^2 ) %*% dout  # shape (d,1)\n\n  g_V2 = Xt[1,] %*% V[1,]\n\n  for (i in 2:d) {\n    tmp = Xt[i,] %*% V[i,]\n    g_V2 = rbind(g_V2, tmp)\n  }\n\n  xv = X %*% V\n\n  g_V1 = dout[,1] * xv[,1]\n\n  for (j in 2:k) {\n    tmp1 = dout[,1] * xv[,k]\n    g_V1 = cbind(g_V1, tmp1)\n  }\n\n  dV = (t(X) %*% g_V1) - g_V2\n  # dV = mean(dout) * (t(X) %*% X %*%V) - g_V2\n}\n";
    }
}

