.. _onnxfloat32and64rst:

===================================
ONNX graph, single or double floats
===================================

.. only:: html

    **Links:** :download:`notebook `, :downloadlink:`html `, :download:`PDF `, :download:`python `, :downloadlink:`slides `, :githublink:`GitHub|_doc/notebooks/onnx_float32_and_64.ipynb|*`

The notebook shows the discrepancies obtained by using double floats
instead of single floats in two cases. The second one involves
`GaussianProcessRegressor `__.

.. code:: ipython3

    from jyquickhelper import add_notebook_menu
    add_notebook_menu()

.. contents::
    :local:

Simple case of a linear regression
----------------------------------

A linear regression is simply a matrix multiplication followed by an
addition: :math:`Y=AX+B`. Let’s train one with `scikit-learn `__.

.. code:: ipython3

    from sklearn.linear_model import LinearRegression
    from sklearn.datasets import load_diabetes
    from sklearn.model_selection import train_test_split

    data = load_diabetes()
    X, y = data.data, data.target
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    clr = LinearRegression()
    clr.fit(X_train, y_train)

.. raw:: html

    LinearRegression()
.. code:: ipython3

    clr.score(X_test, y_test)

.. parsed-literal::

    0.48022823853163243

.. code:: ipython3

    clr.coef_

.. parsed-literal::

    array([  -3.66884712, -248.12455809,  503.47675603,  314.42722272,
           -937.79829646,  589.5139395 ,  166.9937767 ,  238.52080461,
            810.51985926,   83.1649252 ])

.. code:: ipython3

    clr.intercept_

.. parsed-literal::

    151.72119345267856

Let’s predict with *scikit-learn* and *python*.

.. code:: ipython3

    ypred = clr.predict(X_test)
    ypred[:5]

.. parsed-literal::

    array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
           120.17048032])

.. code:: ipython3

    py_pred = X_test @ clr.coef_ + clr.intercept_
    py_pred[:5]

.. parsed-literal::

    array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
           120.17048032])

.. code:: ipython3

    clr.coef_.dtype, clr.intercept_.dtype

.. parsed-literal::

    (dtype('float64'), dtype('float64'))

With ONNX
---------

With *ONNX*, we would write this operation as follows… We still need to
convert everything into single floats = float32.

.. code:: ipython3

    %load_ext mlprodict

.. code:: ipython3

    from skl2onnx.algebra.onnx_ops import OnnxMatMul, OnnxAdd
    import numpy

    onnx_fct = OnnxAdd(OnnxMatMul('X', clr.coef_.astype(numpy.float32), op_version=12),
                       numpy.array([clr.intercept_], dtype=numpy.float32),
                       output_names=['Y'], op_version=12)
    onnx_model32 = onnx_fct.to_onnx({'X': X_test.astype(numpy.float32)})

    # add -l 1 if nothing shows up
    %onnxview onnx_model32

The next line uses a python runtime to compute the prediction.

.. code:: ipython3

    from mlprodict.onnxrt import OnnxInference
    oinf = OnnxInference(onnx_model32, inplace=False)
    ort_pred = oinf.run({'X': X_test.astype(numpy.float32)})['Y']
    ort_pred[:5]

.. parsed-literal::

    array([ 65.190895, 136.63206 , 197.7832  ,  76.509796, 120.17048 ],
          dtype=float32)

And here is the same with `onnxruntime `__\ …

.. code:: ipython3

    oinf = OnnxInference(onnx_model32, runtime="onnxruntime1")
    ort_pred = oinf.run({'X': X_test.astype(numpy.float32)})['Y']
    ort_pred[:5]

.. parsed-literal::

    array([ 65.190895, 136.63206 , 197.7832  ,  76.509796, 120.17048 ],
          dtype=float32)

With double instead of single float
-----------------------------------

`ONNX `__ was originally designed for deep learning, which usually uses
single floats, but that does not mean double floats cannot be used.
Here, every number is converted into double floats.

.. code:: ipython3

    onnx_fct = OnnxAdd(OnnxMatMul('X', clr.coef_.astype(numpy.float64), op_version=12),
                       numpy.array([clr.intercept_], dtype=numpy.float64),
                       output_names=['Y'], op_version=12)
    onnx_model64 = onnx_fct.to_onnx({'X': X_test.astype(numpy.float64)})

And now the *python* runtime…

.. code:: ipython3

    oinf = OnnxInference(onnx_model64)
    ort_pred = oinf.run({'X': X_test})['Y']
    ort_pred[:5]

.. parsed-literal::

    array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
           120.17048032])

And the *onnxruntime* version of it.

.. code:: ipython3

    oinf = OnnxInference(onnx_model64, runtime="onnxruntime1")
    ort_pred = oinf.run({'X': X_test.astype(numpy.float64)})['Y']
    ort_pred[:5]

.. parsed-literal::

    array([ 65.19089869, 136.63206471, 197.78320816,  76.50979441,
           120.17048032])

And now the GaussianProcessRegressor
------------------------------------

This second case shows much larger discrepancies between single and
double floats.

.. code:: ipython3

    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import DotProduct

    gau = GaussianProcessRegressor(alpha=10, kernel=DotProduct())
    gau.fit(X_train, y_train)

.. raw:: html

    GaussianProcessRegressor(alpha=10, kernel=DotProduct(sigma_0=1))
.. code:: ipython3

    from mlprodict.onnx_conv import to_onnx

    onnxgau32 = to_onnx(gau, X_train.astype(numpy.float32))
    oinf32 = OnnxInference(onnxgau32, runtime="python", inplace=False)
    ort_pred32 = oinf32.run({'X': X_test.astype(numpy.float32)})['GPmean']
    numpy.squeeze(ort_pred32)[:25]

.. parsed-literal::

    array([136.    , 146.75  , 156.875 , 137.625 , 143.6875, 157.25  ,
           137.625 , 155.4375, 157.125 , 176.1875, 154.    , 144.6875,
           152.875 , 163.0625, 134.5   , 169.25  , 143.4375, 156.    ,
           147.9375, 147.5625, 143.5625, 139.5   , 167.3125, 162.8125,
           157.5   ], dtype=float32)

.. code:: ipython3

    onnxgau64 = to_onnx(gau, X_train.astype(numpy.float64))
    oinf64 = OnnxInference(onnxgau64, runtime="python", inplace=False)
    ort_pred64 = oinf64.run({'X': X_test.astype(numpy.float64)})['GPmean']
    numpy.squeeze(ort_pred64)[:25]

.. parsed-literal::

    array([136.29042094, 147.37000865, 157.17181659, 137.37942361,
           143.75809938, 157.26946743, 138.0470418 , 155.13779478,
           157.13725317, 176.25699851, 154.58148006, 144.76382797,
           152.92400576, 162.55328615, 135.01672829, 169.57752091,
           144.15882691, 155.9305585 , 147.74172845, 147.95694225,
           143.58627788, 139.44744308, 167.34231253, 162.89442931,
           157.77991459])

The differences between the predictions for single floats and double
floats…

.. code:: ipython3

    numpy.sort(numpy.squeeze(ort_pred32 - ort_pred64))[-5:]

.. parsed-literal::

    array([0.35428989, 0.37583714, 0.39413358, 0.46870174, 0.50921385])

Who’s right or wrong… Let’s look at the differences between each set of
predictions and the original model.

.. code:: ipython3

    pred = gau.predict(X_test.astype(numpy.float64))

.. code:: ipython3

    # GPmean has shape (n, 1) while pred has shape (n,): squeeze before
    # subtracting, otherwise NumPy broadcasts the difference to an
    # (n, n) matrix of pairwise differences.
    numpy.sort(numpy.squeeze(ort_pred32) - pred)[-5:]

.. code:: ipython3

    numpy.sort(numpy.squeeze(ort_pred64) - pred)[-5:]

Double predictions clearly win.

.. code:: ipython3

    # add -l 1 if nothing shows up
    %onnxview onnxgau64

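The ONNX output ``GPmean`` has shape ``(n, 1)`` whereas ``gau.predict`` returns an array of shape ``(n,)``. Subtracting one from the other without squeezing first makes NumPy broadcast both operands to ``(n, n)``, which yields every pairwise difference instead of one difference per observation. The following is a minimal sketch of that behaviour with made-up values (the arrays ``onnx_like`` and ``skl_like`` are illustrative names, not the notebook's data).

```python
import numpy

# Made-up values standing in for an ONNX output of shape (n, 1)
# and a scikit-learn prediction of shape (n,).
onnx_like = numpy.array([[1.0], [2.0], [3.0]])
skl_like = numpy.array([1.0, 2.0, 3.5])

# Broadcasting: (3, 1) - (3,) produces all pairwise differences.
pairwise = onnx_like - skl_like
print(pairwise.shape)     # (3, 3)

# Squeezing first produces one difference per observation.
elementwise = numpy.squeeze(onnx_like) - skl_like
print(elementwise.shape)  # (3,)
print(elementwise)
```

Only the second form compares each prediction with its own reference value.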
Saves…
------

Let’s keep track of it.

.. code:: ipython3

    with open("gpr_dot_product_boston_32.onnx", "wb") as f:
        f.write(onnxgau32.SerializePartialToString())
    from IPython.display import FileLink
    FileLink('gpr_dot_product_boston_32.onnx')

.. raw:: html

    gpr_dot_product_boston_32.onnx

.. code:: ipython3

    with open("gpr_dot_product_boston_64.onnx", "wb") as f:
        f.write(onnxgau64.SerializePartialToString())
    FileLink('gpr_dot_product_boston_64.onnx')

.. raw:: html

    gpr_dot_product_boston_64.onnx

Side by side
------------

We may wonder where the discrepancies start. To find out, we compare the
intermediate results of both models side by side.

.. code:: ipython3

    from mlprodict.onnxrt.validate.side_by_side import side_by_side_by_values
    sbs = side_by_side_by_values([(oinf32, {'X': X_test.astype(numpy.float32)}),
                                  (oinf64, {'X': X_test.astype(numpy.float64)})])

    from pandas import DataFrame
    df = DataFrame(sbs)
    # dfd = df.drop(['value[0]', 'value[1]', 'value[2]'], axis=1).copy()
    df

.. csv-table::
    :header: "", "metric", "step", "v[0]", "v[1]", "cmp", "name", "order[0]", "value[0]", "shape[0]", "order[1]", "value[1]", "shape[1]"

    0, nb_results, -1, 11, 1.100000e+01, OK, NaN, NaN, NaN, NaN, NaN, NaN, NaN
    1, abs-diff, 0, 0, 7.184343e-09, OK, X, 0.0, "[[-0.0018820165, -0.044641636, -0.05147406, -0...", "(111, 10)", 0.0, "[[-0.0018820165277906047, -0.04464163650698914...", "(111, 10)"
    2, abs-diff, 1, 0, 7.241096e-01, ERROR->=0.7, GPmean, 5.0, "[[136.0], [146.75], [156.875], [137.625], [143...", "(111, 1)", 5.0, "[[136.2904209381668], [147.37000865291338], [1...", "(111, 1)"
    3, abs-diff, 2, 0, 7.150779e-09, OK, kgpd_MatMulcst, -1.0, "[[-0.103593096, -0.009147094, 0.016280675, -0....", "(10, 331)", -1.0, "[[-0.10359309315633439, -0.009147093429829445,...", "(10, 331)"
    4, abs-diff, 3, 0, 2.693608e-04, e<0.001, kgpd_Addcst, -1.0, "[23321.936]", "(1,)", -1.0, "[23321.93527751423]", "(1,)"
    5, abs-diff, 4, 0, 9.174340e-07, OK, gpr_MatMulcst, -1.0, "[-6.7274747, 3.3635502, -4.675215, -7.969895, ...", "(331,)", -1.0, "[-6.7274746537081995, 3.363550107698292, -4.67...", "(331,)"
    6, abs-diff, 5, 0, 0.000000e+00, OK, gpr_Addcst, -1.0, "[[0.0]]", "(1, 1)", -1.0, "[[0.0]]", "(1, 1)"
    7, abs-diff, 6, 0, 0.000000e+00, OK, Re_Reshapecst, -1.0, "[-1, 1]", "(2,)", -1.0, "[-1, 1]", "(2,)"
    8, abs-diff, 7, 0, 7.989149e-09, OK, kgpd_Y0, 1.0, "[[0.013952837, 0.004027498, 0.0033139654, 0.01...", "(111, 331)", 1.0, "[[0.013952837286119372, 0.0040274979445440616,...", "(111, 331)"
    9, abs-diff, 8, 0, 1.245899e-03, e<0.01, kgpd_C0, 2.0, "[[23321.95, 23321.94, 23321.94, 23321.953, 233...", "(111, 331)", 2.0, "[[23321.949230351514, 23321.939305012173, 2332...", "(111, 331)"
    10, abs-diff, 9, 0, 7.241096e-01, ERROR->=0.7, gpr_Y0, 3.0, "[136.0, 146.75, 156.875, 137.625, 143.6875, 15...", "(111,)", 3.0, "[136.2904209381668, 147.37000865291338, 157.17...", "(111,)"
    11, abs-diff, 10, 0, 7.241096e-01, ERROR->=0.7, gpr_C0, 4.0, "[[136.0, 146.75, 156.875, 137.625, 143.6875, 1...", "(1, 111)", 4.0, "[[136.2904209381668, 147.37000865291338, 157.1...", "(1, 111)"

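The rows ``kgpd_Y0``, ``kgpd_Addcst`` and ``kgpd_C0`` of the table suggest what happens: kernel values around 0.01 are added to a constant around 23321. The sketch below, which only reuses the two float64 values displayed in the table, looks at the gap between consecutive representable numbers near that constant. In float32 the gap is ``2 ** -9``, so a kernel value of 0.0139 spans only about seven steps and most of its digits are lost in the sum. The variable names are illustrative.

```python
import numpy

# Values copied from the kgpd_Addcst and kgpd_Y0 rows of the table.
offset64 = numpy.float64(23321.93527751423)
small64 = numpy.float64(0.013952837286119372)

offset32 = numpy.float32(offset64)
small32 = numpy.float32(small64)

# Gap between two consecutive representable values near the offset.
step32 = numpy.spacing(offset32)
step64 = numpy.spacing(offset64)
print(step32)            # 0.001953125, i.e. 2 ** -9
print(step64)

# Number of float32 steps covered by the small kernel value.
print(small64 / step32)

# The same addition in both precisions.
sum32 = offset32 + small32
sum64 = offset64 + small64
print(sum32, sum64)
print(abs(numpy.float64(sum32) - sum64))
```

The difference reported for ``kgpd_C0`` (about 1.2e-3) is of the same order as this float32 step.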
The differences really start at output ``'gpr_Y0'``, after the matrix
multiplication. This multiplication mixes numbers with very different
orders of magnitude, and that alone explains the discrepancies between
doubles and floats on that particular model.

.. code:: ipython3

    %matplotlib inline
    ax = df[['name', 'v[1]']].iloc[1:].set_index('name').plot(kind='bar', figsize=(14,4), logy=True)
    ax.set_title("Relative differences for each output between float32 and "
                 "float64\nfor a GaussianProcessRegressor");

.. image:: onnx_float32_and_64_42_0.png

Before going further, let’s check how sensitive the trained model is to
converting doubles into floats.

.. code:: ipython3

    pg1 = gau.predict(X_test)
    pg2 = gau.predict(X_test.astype(numpy.float32).astype(numpy.float64))
    numpy.sort(numpy.squeeze(pg1 - pg2))[-5:]

.. parsed-literal::

    array([2.71829776e-07, 2.75555067e-07, 2.98605300e-07, 3.28873284e-07,
           3.92203219e-07])

Having float or double inputs should not matter. We confirm that with
the model converted into ONNX.

.. code:: ipython3

    p1 = oinf64.run({'X': X_test})['GPmean']
    p2 = oinf64.run({'X': X_test.astype(numpy.float32).astype(numpy.float64)})['GPmean']
    numpy.sort(numpy.squeeze(p1 - p2))[-5:]

.. parsed-literal::

    array([2.71829776e-07, 2.75555067e-07, 2.98605300e-07, 3.28873284e-07,
           3.92203219e-07])

Last verification.

.. code:: ipython3

    sbs = side_by_side_by_values([(oinf64, {'X': X_test.astype(numpy.float32).astype(numpy.float64)}),
                                  (oinf64, {'X': X_test.astype(numpy.float64)})])
    df = DataFrame(sbs)
    ax = df[['name', 'v[1]']].iloc[1:].set_index('name').plot(kind='bar', figsize=(14,4), logy=True)
    ax.set_title("Relative differences for each output between float64 and float64 rounded to float32"
                 "\nfor a GaussianProcessRegressor");

.. image:: onnx_float32_and_64_48_0.png
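The mechanism described in this section (a large constant added to small kernel values, followed by a matrix multiplication with coefficients of mixed signs) can be reproduced on synthetic data. The sketch below is only an illustration under assumptions: the random values, the constant ``offset`` and the names are made up and merely mimic the shapes and magnitudes seen in the side-by-side table. It is not the computation done by the converted model.

```python
import numpy

rng = numpy.random.default_rng(0)

# Synthetic stand-ins (assumptions, not the notebook's data):
# small kernel values, a large additive constant, mixed-sign weights.
small = rng.uniform(0.0, 0.02, size=(111, 331))
offset = 23321.935
weights = rng.normal(0.0, 5.0, size=331)

# Reference computation in float64.
kernel64 = small + offset
ref = kernel64 @ weights

# Same computation with inputs and arithmetic in float32.
kernel32 = kernel64.astype(numpy.float32)
weights32 = weights.astype(numpy.float32)
out32 = kernel32 @ weights32

# float32 inputs but float64 arithmetic: isolates the cost of
# merely storing the inputs in single precision.
out_mixed = kernel32.astype(numpy.float64) @ weights32.astype(numpy.float64)

err32 = numpy.abs(out32.astype(numpy.float64) - ref).max()
err_inputs = numpy.abs(out_mixed - ref).max()
print("float32 inputs and arithmetic:", err32)
print("float32 inputs, float64 arithmetic:", err_inputs)
```

Both errors are many orders of magnitude above what float64 resolves at this scale, which is consistent with the observation that the model is sensitive to the precision of the kernel matrix rather than to the precision of the inputs ``X``.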