# ONNX graph, single or double floats


The notebook shows the discrepancies obtained by using double floats instead of single floats in two cases. The second one involves GaussianProcessRegressor.

## Simple case of a linear regression

A linear regression is simply a matrix multiplication followed by an addition: Y = XA + B. Let’s train one with scikit-learn.

# the original cell does not define `data`; the 10 features and the
# intercept below suggest the diabetes dataset
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

data = load_diabetes()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = LinearRegression()
clr.fit(X_train, y_train)
LinearRegression()
clr.score(X_test, y_test)
0.48022823853163243
clr.coef_
array([  -3.66884712, -248.12455809,  503.47675603,  314.42722272,
-937.79829646,  589.5139395 ,  166.9937767 ,  238.52080461,
810.51985926,   83.1649252 ])
clr.intercept_
151.72119345267856

Let’s predict with scikit-learn and python.

ypred = clr.predict(X_test)
ypred[:5]
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
120.17048032])
py_pred = X_test @ clr.coef_ + clr.intercept_
py_pred[:5]
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
120.17048032])
clr.coef_.dtype, clr.intercept_.dtype
(dtype('float64'), dtype('float64'))
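As a side note, casting those double coefficients down to single precision already loses digits. A small sketch (using values taken from clr.coef_ above) illustrates how much:

```python
import numpy

# Illustration: rounding double-precision coefficients (values taken
# from clr.coef_ above) to single precision keeps about 7 significant
# digits; the absolute error grows with the magnitude of the value.
coef64 = numpy.array([-3.66884712, -248.12455809, 503.47675603])
coef32 = coef64.astype(numpy.float32)
err = numpy.abs(coef64 - coef32.astype(numpy.float64))
print(err)  # no larger than ~2e-5 for values of this magnitude
```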

## With ONNX

With ONNX, we would write this operation as follows, with the operators MatMul and Add. We still need to convert everything into single floats (float32).

import numpy
from skl2onnx.algebra.onnx_ops import OnnxMatMul, OnnxAdd

onnx_fct = OnnxAdd(OnnxMatMul('X', clr.coef_.astype(numpy.float32),
                              op_version=12),
                   numpy.array([clr.intercept_], dtype=numpy.float32),
                   output_names=['Y'], op_version=12)
onnx_model32 = onnx_fct.to_onnx({'X': X_test.astype(numpy.float32)})

# add -l 1 if nothing shows up
%onnxview onnx_model32

The next line uses a python runtime to compute the prediction.

from mlprodict.onnxrt import OnnxInference
oinf = OnnxInference(onnx_model32, inplace=False)
ort_pred = oinf.run({'X': X_test.astype(numpy.float32)})['Y']
ort_pred[:5]
array([ 65.190895, 136.63206 , 197.7832  ,  76.509796, 120.17048 ],
dtype=float32)

And here is the same with onnxruntime.

oinf = OnnxInference(onnx_model32, runtime="onnxruntime1")
ort_pred = oinf.run({'X': X_test.astype(numpy.float32)})['Y']
ort_pred[:5]
array([ 65.190895, 136.63206 , 197.7832  ,  76.509796, 120.17048 ],
dtype=float32)

## With double instead of single float

ONNX was originally designed for deep learning, which usually uses single floats, but that does not mean double floats cannot be used. Every number is now converted into double floats.

onnx_fct = OnnxAdd(OnnxMatMul('X', clr.coef_.astype(numpy.float64),
                              op_version=12),
                   numpy.array([clr.intercept_], dtype=numpy.float64),
                   output_names=['Y'], op_version=12)
onnx_model64 = onnx_fct.to_onnx({'X': X_test.astype(numpy.float64)})

And now the python runtime…

oinf = OnnxInference(onnx_model64)
ort_pred = oinf.run({'X': X_test})['Y']
ort_pred[:5]
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
120.17048032])

And the onnxruntime version of it.

oinf = OnnxInference(onnx_model64, runtime="onnxruntime1")
ort_pred = oinf.run({'X': X_test.astype(numpy.float64)})['Y']
ort_pred[:5]
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
120.17048032])

## And now the GaussianProcessRegressor

This shows a case where the discrepancies between single and double floats become much more significant.

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct
gau = GaussianProcessRegressor(alpha=10, kernel=DotProduct())
gau.fit(X_train, y_train)
GaussianProcessRegressor(alpha=10, kernel=DotProduct(sigma_0=1))
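Before converting, it helps to recall what the ONNX graph has to reproduce. For the DotProduct kernel, the posterior mean is just a kernel matrix product. A minimal sketch on toy data (not the split above), using scikit-learn’s fitted attributes `kernel_.sigma_0` and `alpha_`:

```python
import numpy
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct

# Toy data standing in for the train/test split above (hypothetical).
rng = numpy.random.RandomState(0)
X = rng.randn(20, 3)
y = rng.randn(20)
gp = GaussianProcessRegressor(alpha=10, kernel=DotProduct()).fit(X, y)

# The posterior mean is K(X_new, X) @ alpha_, where the DotProduct
# kernel is k(x, x') = sigma_0**2 + x @ x'.
X_new = rng.randn(5, 3)
K = gp.kernel_.sigma_0 ** 2 + X_new @ X.T
manual_mean = numpy.ravel(K @ gp.alpha_)
assert numpy.allclose(manual_mean, gp.predict(X_new))
```

The big MatMul and Add nodes visible in the converted graph below compute exactly this kernel matrix product.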
from mlprodict.onnx_conv import to_onnx
onnxgau32 = to_onnx(gau, X_train.astype(numpy.float32))
oinf32 = OnnxInference(onnxgau32, runtime="python", inplace=False)
ort_pred32 = oinf32.run({'X': X_test.astype(numpy.float32)})['GPmean']
numpy.squeeze(ort_pred32)[:25]
array([136.    , 146.75  , 156.875 , 137.625 , 143.6875, 157.25  ,
137.625 , 155.4375, 157.125 , 176.1875, 154.    , 144.6875,
152.875 , 163.0625, 134.5   , 169.25  , 143.4375, 156.    ,
147.9375, 147.5625, 143.5625, 139.5   , 167.3125, 162.8125,
157.5   ], dtype=float32)
onnxgau64 = to_onnx(gau, X_train.astype(numpy.float64))
oinf64 = OnnxInference(onnxgau64, runtime="python", inplace=False)
ort_pred64 = oinf64.run({'X': X_test.astype(numpy.float64)})['GPmean']
numpy.squeeze(ort_pred64)[:25]
array([136.29042094, 147.37000865, 157.17181659, 137.37942361,
143.75809938, 157.26946743, 138.0470418 , 155.13779478,
157.13725317, 176.25699851, 154.58148006, 144.76382797,
152.92400576, 162.55328615, 135.01672829, 169.57752091,
144.15882691, 155.9305585 , 147.74172845, 147.95694225,
143.58627788, 139.44744308, 167.34231253, 162.89442931,
157.77991459])

The differences between the predictions for single floats and double floats…

numpy.sort(numpy.squeeze(ort_pred32 - ort_pred64))[-5:]
array([0.35428989, 0.37583714, 0.39413358, 0.46870174, 0.50921385])

Who is right and who is wrong? Let’s look at the differences with the predictions of the original model. Note that ort_pred32 has shape (111, 1) while pred has shape (111,), so the subtraction below broadcasts into a (111, 111) matrix.

pred = gau.predict(X_test.astype(numpy.float64))
numpy.sort(numpy.squeeze(ort_pred32 - pred))[-5:]
array([[-2.53819985e+01, -2.50722714e+01, -2.14450449e+01, ...,
         1.81760154e+01,  1.83160274e+01,  1.87306590e+01],
       ...,
       [-1.05069985e+01, -1.01972714e+01, -6.57004494e+00, ...,
         3.30510154e+01,  3.31910274e+01,  3.36056590e+01]])
numpy.sort(numpy.squeeze(ort_pred64 - pred))[-5:]
array([[-25.3059382 , -24.9962111 , -21.36898463, ...,
         18.25207575,  18.39208772,  18.80671933],
       ...,
       [-10.36350744, -10.05378034,  -6.42655387, ...,
         33.19450651,  33.33451848,  33.74915009]])

The double-float predictions clearly win.
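The two-dimensional arrays above come from NumPy broadcasting: subtracting a (111,) vector from a (111, 1) column produces a (111, 111) matrix. A small sketch with hypothetical data shows the effect and a shape-safe alternative:

```python
import numpy

# Hypothetical shapes mimicking the outputs above: the ONNX prediction
# is a column (n, 1), scikit-learn's predict returns a vector (n,).
n = 111
onnx_pred = numpy.arange(n, dtype=numpy.float64).reshape(-1, 1)
skl_pred = numpy.arange(n, dtype=numpy.float64)

# Subtracting them directly broadcasts into an (n, n) matrix,
# which is why the arrays printed above are two-dimensional.
assert (onnx_pred - skl_pred).shape == (n, n)

# A shape-safe comparison flattens the column first.
diff = numpy.ravel(onnx_pred) - skl_pred
assert diff.shape == (n,)
```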

# add -l 1 if nothing shows up
%onnxview onnxgau64

## Saving the models

Let’s keep track of it.

with open("gpr_dot_product_boston_32.onnx", "wb") as f:
    f.write(onnxgau32.SerializePartialToString())

with open("gpr_dot_product_boston_64.onnx", "wb") as f:
    f.write(onnxgau64.SerializePartialToString())

## Side by side

We may wonder where the discrepancies start. For that, we need to do a side-by-side comparison of the intermediate results.

from mlprodict.onnxrt.validate.side_by_side import side_by_side_by_values
sbs = side_by_side_by_values([(oinf32, {'X': X_test.astype(numpy.float32)}),
                              (oinf64, {'X': X_test.astype(numpy.float64)})])

from pandas import DataFrame
df = DataFrame(sbs)
# dfd = df.drop(['value[0]', 'value[1]', 'value[2]'], axis=1).copy()
df
|    | metric | step | v[0] | v[1] | cmp | name | order[0] | value[0] | shape[0] | order[1] | value[1] | shape[1] |
|----|--------|------|------|------|-----|------|----------|----------|----------|----------|----------|----------|
| 0  | nb_results | -1 | 11 | 1.100000e+01 | OK | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 1  | abs-diff | 0 | 0 | 7.184343e-09 | OK | X | 0.0 | [[-0.0018820165, -0.044641636, -0.05147406, -0... | (111, 10) | 0.0 | [[-0.0018820165277906047, -0.04464163650698914... | (111, 10) |
| 2  | abs-diff | 1 | 0 | 7.241096e-01 | ERROR->=0.7 | GPmean | 5.0 | [[136.0], [146.75], [156.875], [137.625], [143... | (111, 1) | 5.0 | [[136.2904209381668], [147.37000865291338], [1... | (111, 1) |
| 3  | abs-diff | 2 | 0 | 7.150779e-09 | OK | kgpd_MatMulcst | -1.0 | [[-0.103593096, -0.009147094, 0.016280675, -0.... | (10, 331) | -1.0 | [[-0.10359309315633439, -0.009147093429829445,... | (10, 331) |
| 4  | abs-diff | 3 | 0 | 2.693608e-04 | e<0.001 | kgpd_Addcst | -1.0 | [23321.936] | (1,) | -1.0 | [23321.93527751423] | (1,) |
| 5  | abs-diff | 4 | 0 | 9.174340e-07 | OK | gpr_MatMulcst | -1.0 | [-6.7274747, 3.3635502, -4.675215, -7.969895, ... | (331,) | -1.0 | [-6.7274746537081995, 3.363550107698292, -4.67... | (331,) |
| 6  | abs-diff | 5 | 0 | 0.000000e+00 | OK | gpr_Addcst | -1.0 | [[0.0]] | (1, 1) | -1.0 | [[0.0]] | (1, 1) |
| 7  | abs-diff | 6 | 0 | 0.000000e+00 | OK | Re_Reshapecst | -1.0 | [-1, 1] | (2,) | -1.0 | [-1, 1] | (2,) |
| 8  | abs-diff | 7 | 0 | 7.989149e-09 | OK | kgpd_Y0 | 1.0 | [[0.013952837, 0.004027498, 0.0033139654, 0.01... | (111, 331) | 1.0 | [[0.013952837286119372, 0.0040274979445440616,... | (111, 331) |
| 9  | abs-diff | 8 | 0 | 1.245899e-03 | e<0.01 | kgpd_C0 | 2.0 | [[23321.95, 23321.94, 23321.94, 23321.953, 233... | (111, 331) | 2.0 | [[23321.949230351514, 23321.939305012173, 2332... | (111, 331) |
| 10 | abs-diff | 9 | 0 | 7.241096e-01 | ERROR->=0.7 | gpr_Y0 | 3.0 | [136.0, 146.75, 156.875, 137.625, 143.6875, 15... | (111,) | 3.0 | [136.2904209381668, 147.37000865291338, 157.17... | (111,) |
| 11 | abs-diff | 10 | 0 | 7.241096e-01 | ERROR->=0.7 | gpr_C0 | 4.0 | [[136.0, 146.75, 156.875, 137.625, 143.6875, 1... | (1, 111) | 4.0 | [[136.2904209381668, 147.37000865291338, 157.1... | (1, 111) |

The differences really start when the constant kgpd_Addcst (~23322) is added to the kernel matrix (output 'kgpd_C0'): that sum mixes numbers with very different orders of magnitude, and the following matrix multiplication amplifies the rounding errors. That alone explains the discrepancies between single and double floats on this particular model.
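This effect can be checked independently of ONNX. A small sketch showing the float32 resolution around the magnitude of kgpd_Addcst (~23322):

```python
import numpy

# float32 resolution around 23322 (the magnitude of kgpd_Addcst):
# one ULP at that magnitude is 2**-9 ~ 0.002, so any term below
# ~0.001 simply vanishes when added.
big = numpy.float32(23321.936)
small = numpy.float32(1e-4)
print((big + small) - big)                       # 0.0: the small term is lost
print(numpy.spacing(big))                        # 0.001953125 = 2**-9
print(numpy.spacing(numpy.float64(23321.936)))   # 2**-38 ~ 3.6e-12 in double
```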

%matplotlib inline
ax = df[['name', 'v[1]']].iloc[1:].set_index('name').plot(kind='bar', figsize=(14,4), logy=True)
ax.set_title("Relative differences for each output between float32 and "
             "float64\nfor a GaussianProcessRegressor");

Before going further, let’s check how sensitive the trained model is to converting its double inputs into floats.

pg1 = gau.predict(X_test)
pg2 = gau.predict(X_test.astype(numpy.float32).astype(numpy.float64))
numpy.sort(numpy.sort(numpy.squeeze(pg1 - pg2)))[-5:]
array([2.71829776e-07, 2.75555067e-07, 2.98605300e-07, 3.28873284e-07,
3.92203219e-07])

Whether the inputs are floats or doubles barely matters. We confirm that with the model converted into ONNX.

p1 = oinf64.run({'X': X_test})['GPmean']
p2 = oinf64.run({'X': X_test.astype(numpy.float32).astype(numpy.float64)})['GPmean']
numpy.sort(numpy.sort(numpy.squeeze(p1 - p2)))[-5:]
array([2.71829776e-07, 2.75555067e-07, 2.98605300e-07, 3.28873284e-07,
3.92203219e-07])
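The order of magnitude observed above (~4e-7) matches the rounding the inputs themselves undergo. A quick sketch of the relative error introduced by a float64 → float32 → float64 round trip:

```python
import numpy

# The relative error of a float64 -> float32 -> float64 round trip
# is bounded by the float32 machine epsilon (~1.19e-7).
x = numpy.random.RandomState(0).uniform(-1, 1, 1000)
roundtrip = x.astype(numpy.float32).astype(numpy.float64)
rel = numpy.abs(roundtrip - x) / numpy.abs(x)
print(rel.max())  # below numpy.finfo(numpy.float32).eps
```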

Last verification.

sbs = side_by_side_by_values([(oinf64, {'X': X_test.astype(numpy.float32).astype(numpy.float64)}),
                              (oinf64, {'X': X_test.astype(numpy.float64)})])
df = DataFrame(sbs)
ax = df[['name', 'v[1]']].iloc[1:].set_index('name').plot(kind='bar', figsize=(14,4), logy=True)
ax.set_title("Relative differences for each output between float64 and float64 rounded to float32"
             "\nfor a GaussianProcessRegressor");