ONNX graph, single or double floats

This notebook shows the discrepancies between single and double floats in two cases: first a linear regression, then a GaussianProcessRegressor.

Simple case of a linear regression

A linear regression is simply a matrix multiplication followed by an addition: Y = XA + B, where A holds the coefficients and B the intercept. Let’s train one with scikit-learn.

from sklearn.linear_model import LinearRegression
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
data = load_diabetes()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = LinearRegression()
clr.fit(X_train, y_train)
LinearRegression()
clr.score(X_test, y_test)
0.48022823853163243
clr.coef_
array([  -3.66884712, -248.12455809,  503.47675603,  314.42722272,
       -937.79829646,  589.5139395 ,  166.9937767 ,  238.52080461,
        810.51985926,   83.1649252 ])
clr.intercept_
151.72119345267856

Let’s predict with scikit-learn and python.

ypred = clr.predict(X_test)
ypred[:5]
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
       120.17048032])
py_pred = X_test @ clr.coef_ + clr.intercept_
py_pred[:5]
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
       120.17048032])
clr.coef_.dtype, clr.intercept_.dtype
(dtype('float64'), dtype('float64'))
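Fitting such a model is ordinary least squares. The numpy-only sketch below stands in for LinearRegression to make the "matrix multiplication followed by an addition" structure explicit. It uses synthetic, noise-free data rather than the diabetes dataset, and `fit_linear` is a hypothetical helper.

```python
import numpy

def fit_linear(X, y):
    """Least-squares fit of y ~ X @ A + B (a stand-in for LinearRegression)."""
    # Append a column of ones so that the last coefficient plays the role of B.
    Xb = numpy.hstack([X, numpy.ones((X.shape[0], 1))])
    coef, *_ = numpy.linalg.lstsq(Xb, y, rcond=None)
    return coef[:-1], coef[-1]

rng = numpy.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ numpy.array([1.5, -2.0, 0.5]) + 3.0   # synthetic target, no noise

A, B = fit_linear(X, y)
pred = X @ A + B   # matrix multiplication followed by an addition
print(A, B)
```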

With ONNX

With ONNX, this operation is written as follows. Everything still needs to be converted into single floats (float32).

%load_ext mlprodict
from skl2onnx.algebra.onnx_ops import OnnxMatMul, OnnxAdd
import numpy

onnx_fct = OnnxAdd(OnnxMatMul('X', clr.coef_.astype(numpy.float32), op_version=12),
                   numpy.array([clr.intercept_], dtype=numpy.float32),
                   output_names=['Y'], op_version=12)
onnx_model32 = onnx_fct.to_onnx({'X': X_test.astype(numpy.float32)})

# add -l 1 if nothing shows up
%onnxview onnx_model32

The next line uses a python runtime to compute the prediction.

from mlprodict.onnxrt import OnnxInference
oinf = OnnxInference(onnx_model32, inplace=False)
ort_pred = oinf.run({'X': X_test.astype(numpy.float32)})['Y']
ort_pred[:5]
array([ 65.190895, 136.63206 , 197.7832  ,  76.509796, 120.17048 ],
      dtype=float32)

And here is the same computation with onnxruntime.

oinf = OnnxInference(onnx_model32, runtime="onnxruntime1")
ort_pred = oinf.run({'X': X_test.astype(numpy.float32)})['Y']
ort_pred[:5]
array([ 65.190895, 136.63206 , 197.7832  ,  76.509796, 120.17048 ],
      dtype=float32)
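Both runtimes return float32 values that match the double-precision predictions up to float32 precision. The numpy sketch below reproduces the effect without ONNX: it casts inputs, coefficients and intercept to float32 before the same matrix multiplication and addition. The synthetic data loosely mimics the scale of the diabetes features and coefficients; that scale is an assumption. `linear_predict` is a hypothetical helper.

```python
import numpy

def linear_predict(X, A, B, dtype):
    """Compute X @ A + B after casting every operand to dtype."""
    X, A = X.astype(dtype), A.astype(dtype)
    return X @ A + dtype(B)

rng = numpy.random.default_rng(0)
X = rng.normal(scale=0.05, size=(100, 10))   # small features
A = rng.normal(scale=400.0, size=10)         # large coefficients
B = 151.7                                    # intercept of the same order as above

y64 = linear_predict(X, A, B, numpy.float64)
y32 = linear_predict(X, A, B, numpy.float32)
print(y32.dtype, numpy.max(numpy.abs(y32 - y64)))
```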

With double instead of single floats

ONNX was originally designed for deep learning, which usually relies on single floats, but that does not mean double floats cannot be used. Here, every number is converted into double floats.

onnx_fct = OnnxAdd(OnnxMatMul('X', clr.coef_.astype(numpy.float64), op_version=12),
                   numpy.array([clr.intercept_], dtype=numpy.float64),
                   output_names=['Y'], op_version=12)
onnx_model64 = onnx_fct.to_onnx({'X': X_test.astype(numpy.float64)})
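What changes between the two graphs is the precision of every number. The short numpy check below reports the machine epsilon and the approximate number of reliable decimal digits of each type. `precision_summary` is a hypothetical helper.

```python
import numpy

def precision_summary():
    """Machine epsilon and approximate decimal precision of float32 and float64."""
    summary = {}
    for dtype in (numpy.float32, numpy.float64):
        info = numpy.finfo(dtype)
        summary[dtype.__name__] = (float(info.eps), info.precision)
    return summary

print(precision_summary())
```

A float32 keeps about 6 to 7 significant decimal digits, a float64 about 15 to 16.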

And now the python runtime…

oinf = OnnxInference(onnx_model64)
ort_pred = oinf.run({'X': X_test})['Y']
ort_pred[:5]
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
       120.17048032])

And the onnxruntime version of it.

oinf = OnnxInference(onnx_model64, runtime="onnxruntime1")
ort_pred = oinf.run({'X': X_test.astype(numpy.float64)})['Y']
ort_pred[:5]
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
       120.17048032])

And now the GaussianProcessRegressor

This second case shows a model for which the choice between single and double floats does matter.

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct
gau = GaussianProcessRegressor(alpha=10, kernel=DotProduct())
gau.fit(X_train, y_train)
GaussianProcessRegressor(alpha=10, kernel=DotProduct(sigma_0=1))
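With a DotProduct kernel, k(x, x') = sigma_0**2 + x . x'. The mean prediction is K(X, X_train) @ alpha_, where alpha_ solves (K(X_train, X_train) + alpha * I) alpha_ = y_train. The numpy sketch below follows that formula in three steps: a matrix multiplication, the addition of a constant, and a second matrix multiplication. These appear to correspond to the intermediate results kgpd_Y0, kgpd_C0 and gpr_Y0 in the side-by-side table further down.

The sketch makes some assumptions. It keeps sigma_0 fixed, whereas scikit-learn optimizes it during fit, and it uses synthetic data. The helper names are hypothetical.

```python
import numpy

def dot_product_gpr_fit(X_train, y_train, sigma_0, noise):
    """Solve (K + noise * I) alpha_ = y_train with K[i, j] = sigma_0**2 + x_i . x_j."""
    K = X_train @ X_train.T + sigma_0 ** 2
    K[numpy.diag_indices_from(K)] += noise   # plays the role of alpha=10 above
    return numpy.linalg.solve(K, y_train)

def dot_product_gpr_predict(X, X_train, alpha_, sigma_0):
    k = X @ X_train.T       # first matrix multiplication
    k = k + sigma_0 ** 2    # addition of the constant sigma_0**2
    return k @ alpha_       # second matrix multiplication

rng = numpy.random.default_rng(0)
X_train = rng.normal(size=(50, 3))
y_train = X_train @ numpy.array([1.0, -2.0, 0.5]) + 3.0
X_test = rng.normal(size=(5, 3))

alpha_ = dot_product_gpr_fit(X_train, y_train, sigma_0=2.0, noise=10.0)
mean = dot_product_gpr_predict(X_test, X_train, alpha_, sigma_0=2.0)

# Every step is linear in x, so the mean is a linear model in disguise.
w = X_train.T @ alpha_
b = 2.0 ** 2 * alpha_.sum()
print(numpy.allclose(mean, X_test @ w + b))   # True
```

Because every step is linear in x, mean(x) = x . (X_train.T @ alpha_) + sigma_0**2 * sum(alpha_), which the last line checks.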
from mlprodict.onnx_conv import to_onnx
onnxgau32 = to_onnx(gau, X_train.astype(numpy.float32))
oinf32 = OnnxInference(onnxgau32, runtime="python", inplace=False)
ort_pred32 = oinf32.run({'X': X_test.astype(numpy.float32)})['GPmean']
numpy.squeeze(ort_pred32)[:25]
array([136.    , 146.75  , 156.875 , 137.625 , 143.6875, 157.25  ,
       137.625 , 155.4375, 157.125 , 176.1875, 154.    , 144.6875,
       152.875 , 163.0625, 134.5   , 169.25  , 143.4375, 156.    ,
       147.9375, 147.5625, 143.5625, 139.5   , 167.3125, 162.8125,
       157.5   ], dtype=float32)
onnxgau64 = to_onnx(gau, X_train.astype(numpy.float64))
oinf64 = OnnxInference(onnxgau64, runtime="python", inplace=False)
ort_pred64 = oinf64.run({'X': X_test.astype(numpy.float64)})['GPmean']
numpy.squeeze(ort_pred64)[:25]
array([136.29042094, 147.37000865, 157.17181659, 137.37942361,
       143.75809938, 157.26946743, 138.0470418 , 155.13779478,
       157.13725317, 176.25699851, 154.58148006, 144.76382797,
       152.92400576, 162.55328615, 135.01672829, 169.57752091,
       144.15882691, 155.9305585 , 147.74172845, 147.95694225,
       143.58627788, 139.44744308, 167.34231253, 162.89442931,
       157.77991459])

The differences between the predictions for single floats and double floats…

numpy.sort(numpy.squeeze(ort_pred32 - ort_pred64))[-5:]
array([0.35428989, 0.37583714, 0.39413358, 0.46870174, 0.50921385])
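The float32 predictions printed above are all multiples of 1/16 (136.0, 146.75, 156.875, 143.6875, ...), even though a float32 close to 150 can represent steps of 2**-16.

One plausible explanation is that intermediate sums in the final matrix multiplication reach several hundred thousand. At that magnitude float32 only represents multiples of 1/16, and those large terms then cancel out to give results around 150. This is an assumption, not something this notebook measures.

The numpy sketch below only illustrates both facts. `float32_step` and `cancel` are hypothetical helpers.

```python
import numpy

def float32_step(x):
    """Gap between x (rounded to float32) and the next float32."""
    return float(numpy.spacing(numpy.float32(x)))

def cancel(big, small, dtype):
    """Compute (big + small) - big with every operation done in dtype."""
    b = dtype(big)
    return float((b + dtype(small)) - b)

print(float32_step(136.0), float32_step(600000.0))   # 1.52587890625e-05 0.0625
print(cancel(600000.0, 0.1, numpy.float32))          # 0.125
print(cancel(600000.0, 0.1, numpy.float64))          # close to 0.1
```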

Which one is right? Let’s compare both with the predictions of the original scikit-learn model.
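One pitfall when comparing with scikit-learn: the ONNX output GPmean has shape (n, 1), while `gau.predict` returns shape (n,). Subtracting them directly broadcasts into an (n, n) matrix instead of computing an element-wise difference, so the extra dimension must be removed first. Here is a small numpy illustration with made-up arrays:

```python
import numpy

onnx_like = numpy.arange(4.0).reshape(4, 1)   # shape (n, 1), like the GPmean output
skl_like = numpy.arange(4.0)                  # shape (n,), like gau.predict

wrong = onnx_like - skl_like                  # broadcasts to shape (4, 4)
right = numpy.squeeze(onnx_like) - skl_like   # element-wise, shape (4,)
print(wrong.shape, right.shape)               # (4, 4) (4,)
```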

pred = gau.predict(X_test.astype(numpy.float64))
# squeeze first: ort_pred32 has shape (111, 1) while pred has shape (111,)
numpy.sort(numpy.squeeze(ort_pred32) - pred)[-5:]
numpy.sort(numpy.squeeze(ort_pred64) - pred)[-5:]

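Double predictions clearly win. The float64 ONNX model matches the scikit-learn predictions, whereas the float32 model drifts away by up to several tenths.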

# add -l 1 if nothing shows up
%onnxview onnxgau64

Saves

Let’s save both models to keep track of them.

with open("gpr_dot_product_boston_32.onnx", "wb") as f:
    f.write(onnxgau32.SerializePartialToString())
from IPython.display import FileLink
FileLink('gpr_dot_product_boston_32.onnx')
gpr_dot_product_boston_32.onnx
with open("gpr_dot_product_boston_64.onnx", "wb") as f:
    f.write(onnxgau64.SerializePartialToString())
FileLink('gpr_dot_product_boston_64.onnx')
gpr_dot_product_boston_64.onnx

Side by side

Where do the discrepancies start? To find out, we compare the intermediate results of both models side by side.

from mlprodict.onnxrt.validate.side_by_side import side_by_side_by_values
sbs = side_by_side_by_values([(oinf32, {'X': X_test.astype(numpy.float32)}),
                              (oinf64, {'X': X_test.astype(numpy.float64)})])
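The idea behind a side-by-side comparison can be sketched with numpy alone. Run the same sequence of operations once in float32 and once in float64, then measure the largest absolute difference after each step. The steps below mimic the structure of this model (MatMul, Add of a large constant, MatMul). The constant 23321.93527751423 is the value of kgpd_Addcst reported in the table below; everything else is synthetic and an assumption. `side_by_side` is a hypothetical helper, unrelated to the mlprodict function.

```python
import numpy

def side_by_side(X, X_train, constant, weights):
    """Run MatMul, Add, MatMul in float32 and float64 and return the largest
    absolute difference between both precisions after each step."""
    steps = {}
    for dtype in (numpy.float32, numpy.float64):
        y0 = X.astype(dtype) @ X_train.astype(dtype).T   # first MatMul
        c0 = y0 + dtype(constant)                         # Add of a large constant
        y1 = c0 @ weights.astype(dtype)                   # second MatMul
        steps[dtype] = (y0, c0, y1)
    return [float(numpy.max(numpy.abs(a.astype(numpy.float64) - b)))
            for a, b in zip(steps[numpy.float32], steps[numpy.float64])]

rng = numpy.random.default_rng(0)
X = rng.normal(scale=0.05, size=(20, 10))
X_train = rng.normal(scale=0.05, size=(30, 10))
weights = rng.normal(scale=5.0, size=30)
diffs = side_by_side(X, X_train, 23321.93527751423, weights)
print(diffs)
```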

from pandas import DataFrame
df = DataFrame(sbs)
# dfd = df.drop(['value[0]', 'value[1]', 'value[2]'], axis=1).copy()
df
metric step v[0] v[1] cmp name order[0] value[0] shape[0] order[1] value[1] shape[1]
0 nb_results -1 11 1.100000e+01 OK NaN NaN NaN NaN NaN NaN NaN
1 abs-diff 0 0 7.184343e-09 OK X 0.0 [[-0.0018820165, -0.044641636, -0.05147406, -0... (111, 10) 0.0 [[-0.0018820165277906047, -0.04464163650698914... (111, 10)
2 abs-diff 1 0 7.241096e-01 ERROR->=0.7 GPmean 5.0 [[136.0], [146.75], [156.875], [137.625], [143... (111, 1) 5.0 [[136.2904209381668], [147.37000865291338], [1... (111, 1)
3 abs-diff 2 0 7.150779e-09 OK kgpd_MatMulcst -1.0 [[-0.103593096, -0.009147094, 0.016280675, -0.... (10, 331) -1.0 [[-0.10359309315633439, -0.009147093429829445,... (10, 331)
4 abs-diff 3 0 2.693608e-04 e<0.001 kgpd_Addcst -1.0 [23321.936] (1,) -1.0 [23321.93527751423] (1,)
5 abs-diff 4 0 9.174340e-07 OK gpr_MatMulcst -1.0 [-6.7274747, 3.3635502, -4.675215, -7.969895, ... (331,) -1.0 [-6.7274746537081995, 3.363550107698292, -4.67... (331,)
6 abs-diff 5 0 0.000000e+00 OK gpr_Addcst -1.0 [[0.0]] (1, 1) -1.0 [[0.0]] (1, 1)
7 abs-diff 6 0 0.000000e+00 OK Re_Reshapecst -1.0 [-1, 1] (2,) -1.0 [-1, 1] (2,)
8 abs-diff 7 0 7.989149e-09 OK kgpd_Y0 1.0 [[0.013952837, 0.004027498, 0.0033139654, 0.01... (111, 331) 1.0 [[0.013952837286119372, 0.0040274979445440616,... (111, 331)
9 abs-diff 8 0 1.245899e-03 e<0.01 kgpd_C0 2.0 [[23321.95, 23321.94, 23321.94, 23321.953, 233... (111, 331) 2.0 [[23321.949230351514, 23321.939305012173, 2332... (111, 331)
10 abs-diff 9 0 7.241096e-01 ERROR->=0.7 gpr_Y0 3.0 [136.0, 146.75, 156.875, 137.625, 143.6875, 15... (111,) 3.0 [136.2904209381668, 147.37000865291338, 157.17... (111,)
11 abs-diff 10 0 7.241096e-01 ERROR->=0.7 gpr_C0 4.0 [[136.0, 146.75, 156.875, 137.625, 143.6875, 1... (1, 111) 4.0 [[136.2904209381668, 147.37000865291338, 157.1... (1, 111)

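The discrepancies really start with output 'kgpd_C0'. This step adds the constant 23321.936 to the result of the first matrix multiplication, whose values are around 0.01. The discrepancies then grow to about 0.7 after the second matrix multiplication ('gpr_Y0'). 'kgpd_C0' mixes numbers with very different orders of magnitude: a float32 keeps only about 7 significant digits, so most of the small part is lost once it is added to the large constant. That alone explains the discrepancies between doubles and floats on this particular model.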
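Here is a scalar version of what happens in kgpd_C0, using two values reported in the table above: the constant 23321.93527751423 and 0.013952837286119372, the first value of kgpd_Y0. Between 16384 and 32768, consecutive float32 numbers are 2**-9 (about 0.00195) apart. Adding 0.014 to the constant therefore keeps only a coarse approximation of it. `recover_small_part` is a hypothetical helper.

```python
import numpy

CONSTANT = 23321.93527751423   # kgpd_Addcst in the table above
SMALL = 0.013952837286119372   # first value of kgpd_Y0 in the table above

def recover_small_part(dtype):
    """Add SMALL to CONSTANT in dtype, subtract CONSTANT again and return the
    spacing of dtype near CONSTANT, the recovered value and its relative error."""
    c, s = dtype(CONSTANT), dtype(SMALL)
    recovered = (c + s) - c
    rel_err = abs(float(recovered) - float(s)) / float(s)
    return float(numpy.spacing(c)), float(recovered), rel_err

for dtype in (numpy.float32, numpy.float64):
    print(dtype.__name__, recover_small_part(dtype))
```

In float32 about 2% of the small value is lost. In float64 the error stays far below 1e-8.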

%matplotlib inline
ax = df[['name', 'v[1]']].iloc[1:].set_index('name').plot(kind='bar', figsize=(14,4), logy=True)
ax.set_title("Absolute differences for each output between float32 and "
             "float64\nfor a GaussianProcessRegressor");
[Figure: bar chart (log scale) of the difference for each intermediate output between float32 and float64 for a GaussianProcessRegressor]

Before going further, let’s check how sensitive the trained model is to rounding its double inputs to single floats.

pg1 = gau.predict(X_test)
pg2 = gau.predict(X_test.astype(numpy.float32).astype(numpy.float64))
numpy.sort(numpy.squeeze(pg1 - pg2))[-5:]
array([2.71829776e-07, 2.75555067e-07, 2.98605300e-07, 3.28873284e-07,
       3.92203219e-07])
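Rounding a float64 to the nearest float32 changes it by a relative error of at most 2**-24 (about 6e-8). This is consistent with the tiny differences above. Below is a quick numpy check on synthetic values of the same scale as the diabetes features; that scale is an assumption. `max_relative_rounding_error` is a hypothetical helper.

```python
import numpy

def max_relative_rounding_error(x):
    """Largest relative change when a float64 array is rounded to float32."""
    rounded = x.astype(numpy.float32).astype(numpy.float64)
    return float(numpy.max(numpy.abs(rounded - x) / numpy.abs(x)))

rng = numpy.random.default_rng(0)
x = rng.normal(scale=0.05, size=(111, 10))   # synthetic, same shape as X_test
err = max_relative_rounding_error(x)
print(err, 2.0 ** -24)
```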

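Rounding the inputs to float32 barely changes the predictions (the largest differences are below 1e-6), so the input type should not matter. The double-precision model converted into ONNX confirms it.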

p1 = oinf64.run({'X': X_test})['GPmean']
p2 = oinf64.run({'X': X_test.astype(numpy.float32).astype(numpy.float64)})['GPmean']
numpy.sort(numpy.squeeze(p1 - p2))[-5:]
array([2.71829776e-07, 2.75555067e-07, 2.98605300e-07, 3.28873284e-07,
       3.92203219e-07])

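A last verification compares every intermediate result of the double-precision ONNX model for two inputs: the original test data and the same data rounded to float32.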

sbs = side_by_side_by_values([(oinf64, {'X': X_test.astype(numpy.float32).astype(numpy.float64)}),
                              (oinf64, {'X': X_test.astype(numpy.float64)})])
df = DataFrame(sbs)
ax = df[['name', 'v[1]']].iloc[1:].set_index('name').plot(kind='bar', figsize=(14,4), logy=True)
ax.set_title("Absolute differences for each output between float64 and float64 rounded to float32"
             "\nfor a GaussianProcessRegressor");
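[Figure: bar chart (log scale) of the difference for each intermediate output between float64 inputs and float64 inputs rounded to float32 for a GaussianProcessRegressor]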