# ONNX graph, single or double floats

The notebook shows discrepancies obtained by using double floats instead of single floats in two cases. The second one involves a GaussianProcessRegressor.

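## Simple case of a linear regression

A linear regression is simply a matrix multiplication followed by an addition: $Y = XA + B$, where $A$ holds the coefficients (`clr.coef_`) and $B$ the intercept (`clr.intercept_`). Let’s train one with scikit-learn.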

```python
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# 442 samples, 10 features; the default split keeps 331 rows for training, 111 for testing
data = load_diabetes()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = LinearRegression()
clr.fit(X_train, y_train)
```

```
LinearRegression()
```

```python
clr.score(X_test, y_test)
```

```
0.48022823853163243
```

```python
clr.coef_
```

```
array([  -3.66884712, -248.12455809,  503.47675603,  314.42722272,
       -937.79829646,  589.5139395 ,  166.9937767 ,  238.52080461,
        810.51985926,   83.1649252 ])
```

```python
clr.intercept_
```

```
151.72119345267856
```

Let’s predict with scikit-learn and python.

```python
ypred = clr.predict(X_test)
ypred[:5]
```

```
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
       120.17048032])
```

```python
# Same prediction written by hand: X @ coefficients + intercept
py_pred = X_test @ clr.coef_ + clr.intercept_
py_pred[:5]
```

```
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
       120.17048032])
```

```python
clr.coef_.dtype, clr.intercept_.dtype
```

```
(dtype('float64'), dtype('float64'))
```

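## With ONNX

With ONNX, we would write this operation as follows. Everything still needs to be converted into single floats (float32).

```python
import numpy
from skl2onnx.algebra.onnx_ops import OnnxAdd, OnnxMatMul

# Y = X @ coef + intercept, with coefficients and intercept cast to float32
onnx_fct = OnnxAdd(OnnxMatMul('X', clr.coef_.astype(numpy.float32), op_version=12),
                   numpy.array([clr.intercept_], dtype=numpy.float32),
                   output_names=['Y'], op_version=12)
onnx_model32 = onnx_fct.to_onnx({'X': X_test.astype(numpy.float32)})
```

```python
# add -l 1 if nothing shows up
%onnxview onnx_model32
```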

The next line uses a python runtime to compute the prediction.

```python
from mlprodict.onnxrt import OnnxInference
oinf = OnnxInference(onnx_model32, inplace=False)
ort_pred = oinf.run({'X': X_test.astype(numpy.float32)})['Y']
ort_pred[:5]
```

```
array([ 65.190895, 136.63206 , 197.7832  ,  76.509796, 120.17048 ],
      dtype=float32)
```

And here is the same with onnxruntime.

```python
oinf = OnnxInference(onnx_model32, runtime="onnxruntime1")
ort_pred = oinf.run({'X': X_test.astype(numpy.float32)})['Y']
ort_pred[:5]
```

```
array([ 65.190895, 136.63206 , 197.7832  ,  76.509796, 120.17048 ],
      dtype=float32)
```
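The float32 predictions agree with the float64 ones on about seven significant digits. A minimal numpy sketch reproduces this effect without ONNX. It evaluates the same matrix multiplication followed by an addition twice: once in double precision, and once with every operand rounded to float32 first, which is roughly what the float32 graph does. The data, coefficients and intercept below are random stand-ins (an assumption), not the fitted model.

```python
import numpy

# Random stand-ins (assumption): 111 rows and 10 features of the same
# order of magnitude as the scaled diabetes features.
rng = numpy.random.default_rng(0)
X_demo = rng.normal(scale=0.05, size=(111, 10))
coef_demo = rng.normal(scale=500.0, size=10)
intercept_demo = 151.7

# Reference: matrix multiplication followed by an addition, in float64.
pred64 = X_demo @ coef_demo + intercept_demo

# Same computation with every operand rounded to float32 first.
pred32 = (X_demo.astype(numpy.float32) @ coef_demo.astype(numpy.float32)
          + numpy.float32(intercept_demo))

# Largest absolute gap between the two evaluations.
abs_diff = numpy.abs(pred32.astype(numpy.float64) - pred64)
print(pred32.dtype, abs_diff.max())
```

The names end with `_demo` so that running the cell inside the notebook does not overwrite `X`, `clr` or the test data.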

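## With double instead of single floats

ONNX was originally designed for deep learning, which usually relies on single floats, but that does not mean double floats cannot be used. This time, every number is converted into double floats.

```python
from skl2onnx.algebra.onnx_ops import OnnxAdd, OnnxMatMul

# Same graph, but coefficients, intercept and input are float64
onnx_fct = OnnxAdd(OnnxMatMul('X', clr.coef_.astype(numpy.float64), op_version=12),
                   numpy.array([clr.intercept_], dtype=numpy.float64),
                   output_names=['Y'], op_version=12)
onnx_model64 = onnx_fct.to_onnx({'X': X_test.astype(numpy.float64)})
```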

And now with the python runtime:

```python
oinf = OnnxInference(onnx_model64)
ort_pred = oinf.run({'X': X_test})['Y']
ort_pred[:5]
```

```
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
       120.17048032])
```

And the onnxruntime version of it:

```python
oinf = OnnxInference(onnx_model64, runtime="onnxruntime1")
ort_pred = oinf.run({'X': X_test.astype(numpy.float64)})['Y']
ort_pred[:5]
```

```
array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
       120.17048032])
```
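The float64 results match scikit-learn on every printed digit, while the float32 ones stop after about seven. `numpy.finfo` gives the figures behind this: the machine epsilon and the approximate number of reliable decimal digits of each type.

```python
import numpy

info32 = numpy.finfo(numpy.float32)
info64 = numpy.finfo(numpy.float64)

# Machine epsilon: gap between 1.0 and the next representable number.
print(info32.eps, info64.eps)              # 2**-23 and 2**-52
# Approximate number of significant decimal digits.
print(info32.precision, info64.precision)  # 6 and 15
```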

## And now the GaussianProcessRegressor

This model shows a case where single and double floats lead to noticeably different predictions.

```python
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct
gau = GaussianProcessRegressor(alpha=10, kernel=DotProduct())
gau.fit(X_train, y_train)
```

```
GaussianProcessRegressor(alpha=10, kernel=DotProduct(sigma_0=1))
```

```python
from mlprodict.onnx_conv import to_onnx
onnxgau32 = to_onnx(gau, X_train.astype(numpy.float32))
oinf32 = OnnxInference(onnxgau32, runtime="python", inplace=False)
ort_pred32 = oinf32.run({'X': X_test.astype(numpy.float32)})['GPmean']
numpy.squeeze(ort_pred32)[:25]
```

```
array([136.    , 146.75  , 156.875 , 137.625 , 143.6875, 157.25  ,
       137.625 , 155.4375, 157.125 , 176.1875, 154.    , 144.6875,
       152.875 , 163.0625, 134.5   , 169.25  , 143.4375, 156.    ,
       147.9375, 147.5625, 143.5625, 139.5   , 167.3125, 162.8125,
       157.5   ], dtype=float32)
```

```python
onnxgau64 = to_onnx(gau, X_train.astype(numpy.float64))
oinf64 = OnnxInference(onnxgau64, runtime="python", inplace=False)
ort_pred64 = oinf64.run({'X': X_test.astype(numpy.float64)})['GPmean']
numpy.squeeze(ort_pred64)[:25]
```

```
array([136.29042094, 147.37000865, 157.17181659, 137.37942361,
       143.75809938, 157.26946743, 138.0470418 , 155.13779478,
       157.13725317, 176.25699851, 154.58148006, 144.76382797,
       152.92400576, 162.55328615, 135.01672829, 169.57752091,
       144.15882691, 155.9305585 , 147.74172845, 147.95694225,
       143.58627788, 139.44744308, 167.34231253, 162.89442931,
       157.77991459])
```

The five largest differences between the predictions for single floats and double floats:

```python
numpy.sort(numpy.squeeze(ort_pred32 - ort_pred64))[-5:]
```

```
array([0.35428989, 0.37583714, 0.39413358, 0.46870174, 0.50921385])
```

Who’s right or wrong? Let’s compare both with the predictions of the original scikit-learn model. `gau.predict` returns an array of shape `(111,)` while the ONNX output `GPmean` has shape `(111, 1)`, so the ONNX predictions must be squeezed before the subtraction. Subtracting first would broadcast to a `(111, 111)` matrix of every pairwise difference.

```python
pred = gau.predict(X_test.astype(numpy.float64))
numpy.sort(numpy.squeeze(ort_pred32) - pred)[-5:]
```
```python
numpy.sort(numpy.squeeze(ort_pred64) - pred)[-5:]
```

Double precision clearly wins.

```python
# add -l 1 if nothing shows up
%onnxview onnxgau64
```

## Save the models

Let’s save both models to keep track of them.

```python
with open("gpr_dot_product_boston_32.onnx", "wb") as f:
    f.write(onnxgau32.SerializeToString())
with open("gpr_dot_product_boston_64.onnx", "wb") as f:
    f.write(onnxgau64.SerializeToString())
```

## Side by side

We may wonder where the discrepancies start. To find out, we compare the intermediate results of both graphs side by side.

```python
from mlprodict.onnxrt.validate.side_by_side import side_by_side_by_values
sbs = side_by_side_by_values([(oinf32, {'X': X_test.astype(numpy.float32)}),
                              (oinf64, {'X': X_test.astype(numpy.float64)})])
```

```python
from pandas import DataFrame
df = DataFrame(sbs)
# dfd = df.drop(['value[0]', 'value[1]', 'value[2]'], axis=1).copy()
df
```
| | metric | step | v[0] | v[1] | cmp | name | order[0] | value[0] | shape[0] | order[1] | value[1] | shape[1] |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | nb_results | -1 | 11 | 1.100000e+01 | OK | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 1 | abs-diff | 0 | 0 | 7.184343e-09 | OK | X | 0.0 | [[-0.0018820165, -0.044641636, -0.05147406, -0... | (111, 10) | 0.0 | [[-0.0018820165277906047, -0.04464163650698914... | (111, 10) |
| 2 | abs-diff | 1 | 0 | 7.241096e-01 | ERROR->=0.7 | GPmean | 5.0 | [[136.0], [146.75], [156.875], [137.625], [143... | (111, 1) | 5.0 | [[136.2904209381668], [147.37000865291338], [1... | (111, 1) |
| 3 | abs-diff | 2 | 0 | 7.150779e-09 | OK | kgpd_MatMulcst | -1.0 | [[-0.103593096, -0.009147094, 0.016280675, -0.... | (10, 331) | -1.0 | [[-0.10359309315633439, -0.009147093429829445,... | (10, 331) |
| 4 | abs-diff | 3 | 0 | 2.693608e-04 | e<0.001 | kgpd_Addcst | -1.0 | [23321.936] | (1,) | -1.0 | [23321.93527751423] | (1,) |
| 5 | abs-diff | 4 | 0 | 9.174340e-07 | OK | gpr_MatMulcst | -1.0 | [-6.7274747, 3.3635502, -4.675215, -7.969895, ... | (331,) | -1.0 | [-6.7274746537081995, 3.363550107698292, -4.67... | (331,) |
| 6 | abs-diff | 5 | 0 | 0.000000e+00 | OK | gpr_Addcst | -1.0 | [[0.0]] | (1, 1) | -1.0 | [[0.0]] | (1, 1) |
| 7 | abs-diff | 6 | 0 | 0.000000e+00 | OK | Re_Reshapecst | -1.0 | [-1, 1] | (2,) | -1.0 | [-1, 1] | (2,) |
| 8 | abs-diff | 7 | 0 | 7.989149e-09 | OK | kgpd_Y0 | 1.0 | [[0.013952837, 0.004027498, 0.0033139654, 0.01... | (111, 331) | 1.0 | [[0.013952837286119372, 0.0040274979445440616,... | (111, 331) |
| 9 | abs-diff | 8 | 0 | 1.245899e-03 | e<0.01 | kgpd_C0 | 2.0 | [[23321.95, 23321.94, 23321.94, 23321.953, 233... | (111, 331) | 2.0 | [[23321.949230351514, 23321.939305012173, 2332... | (111, 331) |
| 10 | abs-diff | 9 | 0 | 7.241096e-01 | ERROR->=0.7 | gpr_Y0 | 3.0 | [136.0, 146.75, 156.875, 137.625, 143.6875, 15... | (111,) | 3.0 | [136.2904209381668, 147.37000865291338, 157.17... | (111,) |
| 11 | abs-diff | 10 | 0 | 7.241096e-01 | ERROR->=0.7 | gpr_C0 | 4.0 | [[136.0, 146.75, 156.875, 137.625, 143.6875, 1... | (1, 111) | 4.0 | [[136.2904209381668, 147.37000865291338, 157.1... | (1, 111) |
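The pattern in the table, where the error is negligible after the first matrix multiplication, noticeable after the addition and large after the second matrix multiplication, can be mimicked in plain numpy. The mean of a GaussianProcessRegressor is the kernel matrix between test and training points multiplied by dual coefficients, and the DotProduct kernel is $k(x, x') = \sigma_0^2 + x \cdot x'$. The sketch below is not the converted graph: the training and test rows and the dual coefficients `alpha` are random stand-ins (assumptions), and only the kernel constant is copied from `kgpd_Addcst`. It evaluates the three steps (MatMul, Add, MatMul) in float32 and in float64 and prints the largest absolute difference after each one. The names only point to the corresponding outputs of the real graph.

```python
import numpy

# Random stand-ins (assumption): 331 training rows, 111 test rows,
# 10 features scaled like the diabetes data.
rng = numpy.random.default_rng(0)
X_train_demo = rng.normal(scale=0.05, size=(331, 10))
X_test_demo = rng.normal(scale=0.05, size=(111, 10))
const = 23321.93527751423          # kgpd_Addcst from the table
alpha = rng.normal(scale=5.0, size=331)
alpha -= alpha.mean()              # keeps the predictions of moderate size


def gp_mean_steps(dtype):
    """Evaluates MatMul -> Add -> MatMul with every operand cast to dtype."""
    xt = X_test_demo.astype(dtype)
    xr = X_train_demo.astype(dtype)
    k_y0 = xt @ xr.T                   # like kgpd_Y0, shape (111, 331)
    k_c0 = k_y0 + dtype(const)         # like kgpd_C0
    mean = k_c0 @ alpha.astype(dtype)  # like gpr_Y0, shape (111,)
    return {"kgpd_Y0": k_y0, "kgpd_C0": k_c0, "gpr_Y0": mean}


steps32 = gp_mean_steps(numpy.float32)
steps64 = gp_mean_steps(numpy.float64)

# Largest absolute difference between both precisions after each step.
diffs = {name: float(numpy.abs(steps32[name].astype(numpy.float64) - steps64[name]).max())
         for name in steps64}
for name, diff in diffs.items():
    print(f"{name:8s} max abs diff = {diff:.3e}")
```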

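The differences really start with output `kgpd_C0`, which adds the constant `kgpd_Addcst` (about 23321.9) to the result of the first matrix multiplication (`kgpd_Y0`, values around 0.01). In float32, numbers close to 23321 are only spaced about 0.002 apart, so the small term loses most of its significant digits. The error then grows to about 0.72 with `gpr_Y0`, the second matrix multiplication, which sums 331 of these large values weighted by coefficients of both signs. Mixing numbers with very different orders of magnitude alone explains the discrepancies between doubles and floats on this particular model.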
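The precision loss in `kgpd_C0` can be checked by hand. The sketch below takes the constant and the first value of `kgpd_Y0` from the float64 columns of the table and adds them in both precisions. `numpy.spacing` returns the gap between a number and its next neighbour in its own type.

```python
import numpy

# Values copied from the side-by-side table (float64 columns).
const = 23321.93527751423      # kgpd_Addcst
small = 0.013952837286119372   # first value of kgpd_Y0

# Gap between consecutive float32 numbers around the constant: 2**-9.
print(numpy.spacing(numpy.float32(const)))

sum64 = const + small
sum32 = numpy.float32(const) + numpy.float32(small)

# Part of the small term that survives the addition in each precision:
# about 0.0139528 in float64, only 7 * 2**-9 = 0.013671875 in float32.
print(sum64 - const, float(sum32) - float(numpy.float32(const)))
```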

```python
%matplotlib inline
# v[1] holds the largest absolute difference for each output (metric abs-diff)
ax = df[['name', 'v[1]']].iloc[1:].set_index('name').plot(kind='bar', figsize=(14, 4), logy=True)
ax.set_title("Absolute differences for each output between float32 and "
             "float64\nfor a GaussianProcessRegressor");
```

Before going further, let’s check how sensitive the trained model is to its inputs being rounded from double to single floats.

```python
pg1 = gau.predict(X_test)
pg2 = gau.predict(X_test.astype(numpy.float32).astype(numpy.float64))
numpy.sort(numpy.squeeze(pg1 - pg2))[-5:]
```

```
array([2.71829776e-07, 2.75555067e-07, 2.98605300e-07, 3.28873284e-07,
       3.92203219e-07])
```

Rounding the inputs to float32 barely moves the predictions, so having float or double inputs should not matter. We confirm that with the model converted into ONNX in double precision.

```python
p1 = oinf64.run({'X': X_test})['GPmean']
p2 = oinf64.run({'X': X_test.astype(numpy.float32).astype(numpy.float64)})['GPmean']
numpy.sort(numpy.squeeze(p1 - p2))[-5:]
```

```
array([2.71829776e-07, 2.75555067e-07, 2.98605300e-07, 3.28873284e-07,
       3.92203219e-07])
```
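These differences stay tiny because rounding a double to the nearest float32 keeps 24 significant bits. Each input therefore moves by a relative amount of at most $2^{-24} \approx 6 \times 10^{-8}$, and the float64 graph then adds only double-precision rounding. A quick numpy check on random values (an assumption, scaled like the features) illustrates the bound.

```python
import numpy

rng = numpy.random.default_rng(0)
x = rng.normal(scale=0.05, size=(111, 10))

# Round to float32 and back, as X_test.astype(float32).astype(float64) does.
x_rounded = x.astype(numpy.float32).astype(numpy.float64)

# Relative change of each value; round-to-nearest bounds it by 2**-24.
rel_change = numpy.abs(x - x_rounded) / numpy.abs(x)
print(rel_change.max(), 2.0 ** -24)
```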

A last verification: the float64 graph runs once on inputs rounded to float32 and once on the original inputs, and the intermediate results are compared side by side.

```python
sbs = side_by_side_by_values([(oinf64, {'X': X_test.astype(numpy.float32).astype(numpy.float64)}),
                              (oinf64, {'X': X_test.astype(numpy.float64)})])
df = DataFrame(sbs)
ax = df[['name', 'v[1]']].iloc[1:].set_index('name').plot(kind='bar', figsize=(14, 4), logy=True)
ax.set_title("Absolute differences for each output between float64 inputs and inputs rounded to float32"
             "\nfor a GaussianProcessRegressor");
```