Custom Operator for NMF Decomposition#

NMF factorizes an input matrix M into two matrices W and H of rank k so that WH ≈ M. M = (m_ij) may be a binary matrix where i is a user and j a product the user bought. The prediction function depends on whether the recommendation targets an existing user or a new user. This example addresses the first case.

The second case is more complex as it theoretically requires the estimation of new rows of W, typically with gradient descent.
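As a side note, scikit-learn already solves that problem numerically: NMF.transform estimates the coefficients of unseen rows while keeping the learned components H fixed. A minimal, self-contained sketch of the second case could therefore look like the following, where the new user vector is purely hypothetical and the toy matrix is the same one used below.

import numpy as np
from sklearn.decomposition import NMF

mat = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0],
                [1, 0, 0, 0], [1, 0, 0, 0]], dtype=np.float64)
mat[:mat.shape[1], :] += np.identity(mat.shape[1])

mod = NMF(n_components=2)
mod.fit(mat)

# hypothetical new user who bought products 0 and 3
new_user = np.array([[1, 0, 0, 1]], dtype=np.float64)
w_new = mod.transform(new_user)    # estimated coefficients for the new row of W
print(w_new @ mod.components_)     # reconstructed scores for every product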

Building a simple model#

import os
import skl2onnx
import onnxruntime
import sklearn
from sklearn.decomposition import NMF
import numpy as np
import matplotlib.pyplot as plt
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
import onnx
from skl2onnx.algebra.onnx_ops import (
    OnnxArrayFeatureExtractor, OnnxMul, OnnxReduceSum)
from skl2onnx.common.data_types import FloatTensorType
from onnxruntime import InferenceSession


mat = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0],
                [1, 0, 0, 0], [1, 0, 0, 0]], dtype=np.float64)
mat[:mat.shape[1], :] += np.identity(mat.shape[1])

mod = NMF(n_components=2)
W = mod.fit_transform(mat)
H = mod.components_
pred = mod.inverse_transform(W)

print("original predictions")
exp = []
for i in range(mat.shape[0]):
    for j in range(mat.shape[1]):
        exp.append((i, j, pred[i, j]))

print(exp)
original predictions
[(0, 0, 1.8940570076830285), (0, 1, 0.3072441822407282), (0, 2, 0.3072441822407282), (0, 3, 0.10911047375804787), (1, 0, 1.1066071879294734), (1, 1, 0.1908338542786808), (1, 2, 0.1908338542786808), (1, 3, 0.0), (2, 0, 1.1066071879294734), (2, 1, 0.1908338542786808), (2, 2, 0.1908338542786808), (2, 3, 0.0), (3, 0, 1.0146710371562229), (3, 1, 0.0), (3, 2, 0.0), (3, 3, 0.9848903284716739), (4, 0, 0.9470285038415143), (4, 1, 0.1536220911203641), (4, 2, 0.1536220911203641), (4, 3, 0.05455523687902394)]

Let’s rewrite the prediction in a way that is closer to the function we need to convert into ONNX.

def predict(W, H, row_index, col_index):
    return np.dot(W[row_index, :], H[:, col_index])


got = []
for i in range(mat.shape[0]):
    for j in range(mat.shape[1]):
        got.append((i, j, predict(W, H, i, j)))

print(got)
[(0, 0, 1.8940570076830285), (0, 1, 0.3072441822407282), (0, 2, 0.3072441822407282), (0, 3, 0.10911047375804787), (1, 0, 1.1066071879294734), (1, 1, 0.1908338542786808), (1, 2, 0.1908338542786808), (1, 3, 0.0), (2, 0, 1.1066071879294734), (2, 1, 0.1908338542786808), (2, 2, 0.1908338542786808), (2, 3, 0.0), (3, 0, 1.0146710371562229), (3, 1, 0.0), (3, 2, 0.0), (3, 3, 0.9848903284716739), (4, 0, 0.9470285038415143), (4, 1, 0.1536220911203641), (4, 2, 0.1536220911203641), (4, 3, 0.05455523687902394)]
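As a quick sanity check, not part of the original script, the rewritten function can be compared against the values returned by inverse_transform:

# the rewritten predict returns exactly the values of inverse_transform
assert np.allclose([p for _, _, p in exp], [p for _, _, p in got])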

Conversion into ONNX#

There is no implemented converter for NMF as the function we plan to convert is neither a transformer nor a predictor. The following converter does not need to be registered; it just creates an ONNX graph equivalent to the function predict implemented above.

def nmf_to_onnx(W, H, op_version=12):
    """
    The function converts a NMF described by matrices
    *W*, *H* (*WH* approximate training data *M*).
    into a function which takes two indices *(i, j)*
    and returns the predictions for it. It assumes
    these indices applies on the training data.
    """
    col = OnnxArrayFeatureExtractor(H, 'col')
    row = OnnxArrayFeatureExtractor(W.T, 'row')
    dot = OnnxMul(col, row, op_version=op_version)
    res = OnnxReduceSum(dot, output_names="rec", op_version=op_version)
    indices_type = np.array([0], dtype=np.int64)
    onx = res.to_onnx(inputs={'col': indices_type,
                              'row': indices_type},
                      outputs=[('rec', FloatTensorType((None, 1)))],
                      target_opset=op_version)
    return onx


model_onnx = nmf_to_onnx(W.astype(np.float32),
                         H.astype(np.float32))
print(model_onnx)
ir_version: 7
producer_name: "skl2onnx"
producer_version: "1.14.0"
domain: "ai.onnx"
model_version: 0
graph {
  node {
    input: "Ar_ArrayFeatureExtractorcst"
    input: "col"
    output: "Ar_Z0"
    name: "Ar_ArrayFeatureExtractor"
    op_type: "ArrayFeatureExtractor"
    domain: "ai.onnx.ml"
  }
  node {
    input: "Ar_ArrayFeatureExtractorcst1"
    input: "row"
    output: "Ar_Z02"
    name: "Ar_ArrayFeatureExtractor1"
    op_type: "ArrayFeatureExtractor"
    domain: "ai.onnx.ml"
  }
  node {
    input: "Ar_Z0"
    input: "Ar_Z02"
    output: "Mu_C0"
    name: "Mu_Mul"
    op_type: "Mul"
    domain: ""
  }
  node {
    input: "Mu_C0"
    output: "rec"
    name: "Re_ReduceSum"
    op_type: "ReduceSum"
    domain: ""
  }
  name: "OnnxReduceSum"
  initializer {
    dims: 2
    dims: 4
    data_type: 1
    float_data: 1.9789423942565918
    float_data: 0.3412676155567169
    float_data: 0.3412676155567169
    float_data: 0.0
    float_data: 0.8960736989974976
    float_data: 0.0
    float_data: 0.0
    float_data: 0.869773805141449
    name: "Ar_ArrayFeatureExtractorcst"
  }
  initializer {
    dims: 2
    dims: 5
    data_type: 1
    float_data: 0.9003027677536011
    float_data: 0.5591912269592285
    float_data: 0.5591912269592285
    float_data: 0.0
    float_data: 0.45015138387680054
    float_data: 0.12544694542884827
    float_data: 0.0
    float_data: 0.0
    float_data: 1.1323522329330444
    float_data: 0.06272347271442413
    name: "Ar_ArrayFeatureExtractorcst1"
  }
  input {
    name: "col"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
          }
        }
      }
    }
  }
  input {
    name: "row"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
          }
        }
      }
    }
  }
  output {
    name: "rec"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
          }
          dim {
            dim_value: 1
          }
        }
      }
    }
  }
}
opset_import {
  domain: ""
  version: 12
}
opset_import {
  domain: "ai.onnx.ml"
  version: 1
}

Let’s compute predictions with it.

sess = InferenceSession(model_onnx.SerializeToString())


def predict_onnx(sess, row_indices, col_indices):
    res = sess.run(None,
                   {'col': col_indices,
                    'row': row_indices})
    return res


onnx_preds = []
for i in range(mat.shape[0]):
    for j in range(mat.shape[1]):
        row_indices = np.array([i], dtype=np.int64)
        col_indices = np.array([j], dtype=np.int64)
        pred = predict_onnx(sess, row_indices, col_indices)[0]
        onnx_preds.append((i, j, pred[0, 0]))

print(onnx_preds)
[(0, 0, 1.894057), (0, 1, 0.30724418), (0, 2, 0.30724418), (0, 3, 0.10911047), (1, 0, 1.1066072), (1, 1, 0.19083385), (1, 2, 0.19083385), (1, 3, 0.0), (2, 0, 1.1066072), (2, 1, 0.19083385), (2, 2, 0.19083385), (2, 3, 0.0), (3, 0, 1.0146711), (3, 1, 0.0), (3, 2, 0.0), (3, 3, 0.9848903), (4, 0, 0.9470285), (4, 1, 0.15362209), (4, 2, 0.15362209), (4, 3, 0.054555234)]
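The values match the original predictions up to the float32 precision used to store W and H in the graph. A short check, again not part of the original script:

# largest discrepancy between scikit-learn and onnxruntime predictions
diff = max(abs(p1 - p2) for (_, _, p1), (_, _, p2) in zip(exp, onnx_preds))
print("maximum absolute difference:", diff)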

The ONNX graph looks like the following.

pydot_graph = GetPydotGraph(
    model_onnx.graph, name=model_onnx.graph.name,
    rankdir="TB", node_producer=GetOpNodeProducer("docstring"))
pydot_graph.write_dot("graph_nmf.dot")
os.system('dot -O -Tpng graph_nmf.dot')
image = plt.imread("graph_nmf.dot.png")
plt.imshow(image)
plt.axis('off')
[figure: plot nmf, the ONNX graph rendered from graph_nmf.dot]

Versions used for this example

print("numpy:", np.__version__)
print("scikit-learn:", sklearn.__version__)
print("onnx: ", onnx.__version__)
print("onnxruntime: ", onnxruntime.__version__)
print("skl2onnx: ", skl2onnx.__version__)
numpy: 1.23.5
scikit-learn: 1.2.2
onnx:  1.13.1
onnxruntime:  1.14.1
skl2onnx:  1.14.0

Total running time of the script: ( 0 minutes 1.230 seconds)
