.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_ebegin_float_double.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code or to run this example in your browser via Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_ebegin_float_double.py:


.. _l-example-discrepencies-float-double:

Issues when switching to float
==============================

.. index:: float, double, discrepancies

Most models in :epkg:`scikit-learn` compute with double, not float.
Most models in deep learning use float because it is the most common
type on GPUs. ONNX was initially created to facilitate the deployment
of deep learning models, which explains why many converters assume
the converted models should use float. That assumption does not usually
harm the predictions: the conversion to float introduces small
discrepancies compared to double predictions. This holds when the
prediction function is continuous: if :math:`y = f(x)`, then
:math:`dy = f'(x) dx`, and the discrepancies have an upper bound,
:math:`\Delta(y) \leqslant \sup_x \left\Vert f'(x)\right\Vert dx`,
where *dx* is the discrepancy introduced by the float conversion,
``dx = x - numpy.float32(x)``.

However, not every model is continuous. A decision tree trained for
a regression is a piecewise constant function, so even a small *dx*
may introduce a huge discrepancy. Let's look into an example which
always produces discrepancies and into some ways to overcome this
situation.

.. contents::
    :local:

More into the issue
+++++++++++++++++++

The following example is built to fail. It contains features with
different orders of magnitude, rounded to integers. A decision tree
compares features to thresholds. In most cases, comparisons in float
and in double give the same result.
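
The contrast between the continuous case and a threshold can be
sketched with numpy alone. The functions ``linear`` and ``step`` below
are hypothetical stand-ins chosen for illustration: the first one is
continuous with :math:`f'(x) = 3`, the second one behaves like a single
decision tree node with threshold ``1.0``. The input ``1 + 1e-8``
is rounded to ``1.0`` by the cast to float.

.. code-block:: python

    import numpy


    def linear(x):
        # Continuous function, f'(x) = 3 everywhere.
        return 3.0 * x


    def step(x, threshold):
        # Piecewise constant function, like a single decision tree node.
        return numpy.where(x <= threshold, 0.0, 100.0)


    x = numpy.float64(1.0 + 1e-8)
    # Value seen by a model computing in float, converted back to double.
    x32 = numpy.float64(numpy.float32(x))
    dx = abs(x - x32)

    # Continuous case: the error stays below sup|f'| * dx = 3 * dx.
    err_linear = abs(linear(x) - linear(x32))
    # Discontinuous case: the threshold 1.0 lies between x32 and x,
    # so the output jumps from 100 to 0.
    err_step = abs(step(x, 1.0) - step(x32, 1.0))

    print(dx, err_linear, err_step)

The error of the continuous function is of the same order as *dx*,
while the error of the step function is the full height of the jump.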

We denote by :math:`[x]_{f32}` the conversion (or cast)
``numpy.float32(x)``. The following equivalence holds in most cases:

.. math::

    x \leqslant y \Longleftrightarrow [x]_{f32} \leqslant [y]_{f32}

However, the probability that both comparisons give different results
is not zero. The following graph shows the areas where they disagree.

.. GENERATED FROM PYTHON SOURCE LINES 53-107

.. code-block:: default

    from skl2onnx.sklapi import CastRegressor
    from mlprodict.onnxrt import OnnxInference
    from mlprodict.onnx_conv import to_onnx as to_onnx_extended
    from mlprodict.sklapi import OnnxPipeline
    from skl2onnx.sklapi import CastTransformer
    from skl2onnx import to_onnx
    from onnxruntime import InferenceSession
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeRegressor
    from sklearn.preprocessing import StandardScaler
    from sklearn.pipeline import Pipeline
    from sklearn.datasets import make_regression
    import numpy
    import matplotlib.pyplot as plt


    def area_mismatch_rule(N, delta, factor, rule=None):
        # By default, the second comparison is done after a cast to float.
        if rule is None:
            def rule(t): return numpy.float32(t)
        xst = []
        yst = []
        xsf = []
        ysf = []
        for x in range(-N, N):
            for y in range(-N, N):
                # Two values close to factor, separated by multiples of delta.
                dx = (1. + x * delta) * factor
                dy = (1. + y * delta) * factor
                # c1: comparison in double, c2: comparison in float.
                c1 = 1 if numpy.float64(dx) <= numpy.float64(dy) else 0
                c2 = 1 if numpy.float32(dx) <= rule(dy) else 0
                key = abs(c1 - c2)
                if key == 1:
                    # Both comparisons disagree.
                    xsf.append(dx)
                    ysf.append(dy)
                else:
                    xst.append(dx)
                    yst.append(dy)
        return xst, yst, xsf, ysf


    delta = 36e-10
    factor = 1
    xst, yst, xsf, ysf = area_mismatch_rule(100, delta, factor)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.plot(xst, yst, '.', label="agree")
    ax.plot(xsf, ysf, '.', label="disagree")
    ax.set_title("Region where x <= y and (float)x <= (float)y agree")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.plot([min(xst), max(xst)], [min(yst), max(yst)], 'k--')
    ax.legend()


.. image:: /auto_examples/images/sphx_glr_plot_ebegin_float_double_001.png
    :alt: Region where x <= y and (float)x <= (float)y agree
    :class: sphx-glr-single-img
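
The graph can be complemented with a numeric check. The helper
``mismatch_fraction`` below is a hypothetical vectorized variant of
``area_mismatch_rule``, written for illustration: it builds the same
grid, applies the default float rule and returns the fraction of pairs
on which both comparisons disagree. The pair ``x = 1 + 1e-8``,
``y = 1`` is one explicit disagreement: :math:`x > y` in double, but
both values become ``1.0`` in float.

.. code-block:: python

    import numpy


    def mismatch_fraction(N, delta, factor=1.0):
        # Same grid as area_mismatch_rule: (1 + k * delta) * factor, k in [-N, N).
        k = numpy.arange(-N, N, dtype=numpy.float64)
        values = (1.0 + k * delta) * factor
        x, y = numpy.meshgrid(values, values, indexing="ij")
        c64 = x <= y  # comparison in double
        c32 = x.astype(numpy.float32) <= y.astype(numpy.float32)  # in float
        return float(numpy.mean(c64 != c32))


    # One explicit disagreement: 1 + 1e-8 is rounded to 1.0 in float.
    x, y = 1.0 + 1e-8, 1.0
    print(x <= y, numpy.float32(x) <= numpy.float32(y))

    print(mismatch_fraction(100, 36e-10))  # same setting as the graph
    print(mismatch_fraction(100, 1.0))     # integers are exact in float

With ``delta = 1.0``, every value on the grid is a small integer, exactly
representable in float, so no pair disagrees.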

.. GENERATED FROM PYTHON SOURCE LINES 108-115

The pipeline and the data
+++++++++++++++++++++++++

We can now build an example where the learned decision tree does many
comparisons in this disagreement area. This is done by rounding
features to integers, a frequent situation when dealing with
categorical features.

.. GENERATED FROM PYTHON SOURCE LINES 115-135

.. code-block:: default


    X, y = make_regression(10000, 10)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    # Feature i is multiplied by 2 ** i and truncated to an integer,
    # so features have different orders of magnitude.
    Xi_train, yi_train = X_train.copy(), y_train.copy()
    Xi_test, yi_test = X_test.copy(), y_test.copy()
    for i in range(X.shape[1]):
        Xi_train[:, i] = (Xi_train[:, i] * 2 ** i).astype(numpy.int64)
        Xi_test[:, i] = (Xi_test[:, i] * 2 ** i).astype(numpy.int64)

    max_depth = 10
    model = Pipeline([
        ('scaler', StandardScaler()),
        ('dt', DecisionTreeRegressor(max_depth=max_depth))
    ])

    model.fit(Xi_train, yi_train)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Pipeline(steps=[('scaler', StandardScaler()),
                    ('dt', DecisionTreeRegressor(max_depth=10))])

.. GENERATED FROM PYTHON SOURCE LINES 136-142

The discrepancies
+++++++++++++++++

Let's reuse the function implemented in the first example
:ref:`l-diff-dicrepencies` and look into the conversion.

.. GENERATED FROM PYTHON SOURCE LINES 142-162

.. code-block:: default


    def diff(p1, p2):
        # Maximum absolute and maximum relative differences.
        p1 = p1.ravel()
        p2 = p2.ravel()
        d = numpy.abs(p2 - p1)
        return d.max(), (d / numpy.abs(p1)).max()


    onx = to_onnx(model, Xi_train[:1].astype(numpy.float32))
    sess = InferenceSession(onx.SerializeToString())

    X32 = Xi_test.astype(numpy.float32)

    skl = model.predict(X32)
    ort = sess.run(None, {'X': X32})[0]

    print(diff(skl, ort))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    (103.68835894268832, 1.1565162623680272)

.. GENERATED FROM PYTHON SOURCE LINES 163-196

The discrepancies are significant.
The ONNX model keeps float at every step.

.. blockdiag::

   diagram {
     x_float32 -> normalizer -> y_float32 -> dtree -> z_float32
   }

In :epkg:`scikit-learn`:

.. blockdiag::

   diagram {
     x_float32 -> normalizer -> y_double -> dtree -> z_double
   }

CastTransformer
+++++++++++++++

We could try to use double everywhere. Unfortunately,
:epkg:`ONNX ML Operators` only allows float coefficients for the
operator *TreeEnsembleRegressor*. We may want to compromise by casting
the output of the normalizer into float in the :epkg:`scikit-learn`
pipeline.

.. blockdiag::

   diagram {
     x_float32 -> normalizer -> y_double -> cast -> y_float -> dtree -> z_float
   }

.. GENERATED FROM PYTHON SOURCE LINES 196-206

.. code-block:: default


    model2 = Pipeline([
        ('scaler', StandardScaler()),
        ('cast', CastTransformer()),
        ('dt', DecisionTreeRegressor(max_depth=max_depth))
    ])

    model2.fit(Xi_train, yi_train)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Pipeline(steps=[('scaler', StandardScaler()), ('cast', CastTransformer()),
                    ('dt', DecisionTreeRegressor(max_depth=10))])

.. GENERATED FROM PYTHON SOURCE LINES 207-208

The discrepancies.

.. GENERATED FROM PYTHON SOURCE LINES 208-218

.. code-block:: default


    onx2 = to_onnx(model2, Xi_train[:1].astype(numpy.float32))
    sess2 = InferenceSession(onx2.SerializeToString())

    skl2 = model2.predict(X32)
    ort2 = sess2.run(None, {'X': X32})[0]

    print(diff(skl2, ort2))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    (103.68835894268831, 1.156516262368027)

.. GENERATED FROM PYTHON SOURCE LINES 219-224

It still fails because the normalizer does not use the same type in
:epkg:`scikit-learn` and in :epkg:`ONNX`: :epkg:`scikit-learn` scales
in double and then casts the result into float, whereas the ONNX
normalizer computes directly in float. Both paths produce slightly
different float inputs for the tree, so the *dx* is still here.
To remove it, the ONNX normalizer needs to compute in double.

.. GENERATED FROM PYTHON SOURCE LINES 224-243
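
The remaining *dx* can be reproduced with numpy alone. This is a
simplified sketch, not the actual converter: it assumes integer
features and scaler statistics ``m`` and ``s`` computed on synthetic
data, then compares an all-float computation of :math:`(x - m) / s`
(roughly what the ONNX normalizer of the second pipeline does, up to
the exact sequence of operations) with a double computation followed by
a cast into float (what the :epkg:`scikit-learn` pipeline with
``CastTransformer`` does).

.. code-block:: python

    import numpy

    rng = numpy.random.RandomState(0)
    # Synthetic integer features, similar in spirit to the example above.
    x = rng.randint(-1000, 1000, size=10000).astype(numpy.float64)
    # Hypothetical scaler statistics, stored as doubles.
    m, s = x.mean(), x.std()

    # scikit-learn + CastTransformer: scale in double, then cast into float.
    y_double_then_cast = ((x - m) / s).astype(numpy.float32)

    # Float-only normalizer: every operation is done in float.
    x32, m32, s32 = x.astype(numpy.float32), numpy.float32(m), numpy.float32(s)
    y_float = (x32 - m32) / s32

    n_diff = int(numpy.sum(y_double_then_cast != y_float))
    print(n_diff, "out of", x.shape[0], "scaled values differ")

Each difference is at most a few units in the last place, but any of
them may fall on the other side of a tree threshold.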

.. code-block:: default


    model3 = Pipeline([
        ('cast64', CastTransformer(dtype=numpy.float64)),
        ('scaler', StandardScaler()),
        ('cast', CastTransformer()),
        ('dt', DecisionTreeRegressor(max_depth=max_depth))
    ])

    model3.fit(Xi_train, yi_train)
    onx3 = to_onnx(model3, Xi_train[:1].astype(numpy.float32),
                   options={StandardScaler: {'div': 'div_cast'}})
    sess3 = InferenceSession(onx3.SerializeToString())

    skl3 = model3.predict(X32)
    ort3 = sess3.run(None, {'X': X32})[0]

    print(diff(skl3, ort3))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    (3.03531841154836e-05, 5.847676720404156e-08)

.. GENERATED FROM PYTHON SOURCE LINES 244-267

It works. That also means that it is difficult to change the
computation type when a pipeline includes a discontinuous function.
It is better to keep the same type all along the pipeline before using
a decision tree.

Sledgehammer
++++++++++++

The idea here is to always train the next step on ONNX outputs,
so that every step of the pipeline is trained on the data it receives
once the pipeline is converted into ONNX:

* Train the first step.
* Convert the step into ONNX.
* Compute the ONNX outputs.
* Train the second step on these outputs.
* Convert the second step into ONNX.
* Merge it with the first step.
* Compute the ONNX outputs of the first two merged steps.
* ...

This is implemented in class :epkg:`OnnxPipeline`.

.. GENERATED FROM PYTHON SOURCE LINES 267-276

.. code-block:: default


    model_onx = OnnxPipeline([
        ('scaler', StandardScaler()),
        ('dt', DecisionTreeRegressor(max_depth=max_depth))
    ])

    model_onx.fit(Xi_train, yi_train)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    OnnxPipeline(steps=[('scaler', OnnxTransformer(onnx_bytes=b'\x08\x06\x12\x08skl2onnx\x1a\x081.9.3001"\x07ai.onnx(\x002\x00:\xf6\x01\n\xa6\x01\n\x01X\x12\x08variable\x1a\x06Scaler"\x06Scaler*=\n\x06offset=e\xcf\x8b:=\xd6\xea\x18\xbd=\x8c\x00\x14==hlB==\x8fx\x9f<=\x08\xf3\xc4\xbe=\xfd\xf6\xf5>=\x9aw\\\xbf=\x0c\x93)>=ElAA\xa0\x01\x06*<\n\x05scale=\x85D...\x8c>=\xfc\xb3\x07>=\xc6:\x81==\x96\xa2\x02==d\xbd\x7f<=\x07\n\xff;=\x86u\x81;=\xff+\xff:\xa0\x01\x06:\nai.onnx.ml\x12\x1emlprodict_ONNX(StandardScaler)Z\x11\n\x01X\x12\x0c\n\n\x08\x01\x12\x06\n\x00\n\x02\x08\nb\x18\n\x08variable\x12\x0c\n\n\x08\x01\x12\x06\n\x00\n\x02\x08\nB\x0e\n\nai.onnx.ml\x10\x01')),
                        ('dt', DecisionTreeRegressor(max_depth=10))])

.. GENERATED FROM PYTHON SOURCE LINES 277-278

The conversion.

.. GENERATED FROM PYTHON SOURCE LINES 278-288

.. code-block:: default


    onx4 = to_onnx(model_onx, Xi_train[:1].astype(numpy.float32))
    sess4 = InferenceSession(onx4.SerializeToString())

    skl4 = model_onx.predict(X32)
    ort4 = sess4.run(None, {'X': X32})[0]

    print(diff(skl4, ort4))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    (3.03531841154836e-05, 5.847676720404156e-08)

.. GENERATED FROM PYTHON SOURCE LINES 289-290

It works too, in a simpler way.

.. GENERATED FROM PYTHON SOURCE LINES 292-312

No discrepancies at all?
++++++++++++++++++++++++

Is it possible to get no error at all?
There is one major obstacle: :epkg:`scikit-learn` stores the predicted
value of every leaf as a double
(`_tree.pyx - _get_value_ndarray `_),
whereas :epkg:`ONNX` defines the predicted values as floats
(:epkg:`TreeEnsembleRegressor`). What can we do to solve it?
What if we could extend the ONNX specifications to support double
instead of float? We reuse what was developed in the example
`Other way to convert `_
and a custom ONNX node `TreeEnsembleRegressorDouble `_.

.. GENERATED FROM PYTHON SOURCE LINES 312-323
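
The obstacle can be observed on any fitted tree. In
:epkg:`scikit-learn`, attribute ``tree_.value`` holds the predicted
value of every node as a double array. The sketch below uses synthetic
data from ``make_regression``, not the data of this example, and
measures how much those values change once stored as floats, which is
what a tree operator accepting only float coefficients has to do.

.. code-block:: python

    import numpy
    from sklearn.datasets import make_regression
    from sklearn.tree import DecisionTreeRegressor

    # Synthetic data, for illustration only.
    X, y = make_regression(1000, 5, random_state=0)
    tree = DecisionTreeRegressor(max_depth=10, random_state=0).fit(X, y)

    values64 = tree.tree_.value                # value of every node, in double
    values32 = values64.astype(numpy.float32)  # values kept by a float-only runtime

    print(values64.dtype, values32.dtype)
    print("max absolute change:", numpy.abs(values64 - values32).max())

This change happens even when both sides reach exactly the same leaf,
so it remains after every threshold comparison agrees.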

.. code-block:: default


    tree = DecisionTreeRegressor(max_depth=max_depth)
    tree.fit(Xi_train, yi_train)

    model_onx = to_onnx_extended(tree, Xi_train[:1].astype(numpy.float64),
                                 rewrite_ops=True)
    oinf5 = OnnxInference(model_onx, runtime='python_compiled')
    print(oinf5)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    OnnxInference(...)
        def compiled_run(dict_inputs):
            # inputs
            X = dict_inputs['X']
            (variable, ) = n0_treeensembleregressordouble(X)
            return {
                'variable': variable,
            }

.. GENERATED FROM PYTHON SOURCE LINES 324-325

Let's measure the discrepancies.

.. GENERATED FROM PYTHON SOURCE LINES 325-330

.. code-block:: default


    X64 = Xi_test.astype(numpy.float64)
    skl5 = tree.predict(X64)
    ort5 = oinf5.run({'X': X64})['variable']

.. GENERATED FROM PYTHON SOURCE LINES 331-332

Perfect, no discrepancies at all.

.. GENERATED FROM PYTHON SOURCE LINES 332-335

.. code-block:: default


    print(diff(skl5, ort5))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    (0.0, 0.0)

.. GENERATED FROM PYTHON SOURCE LINES 336-344

CastRegressor
+++++++++++++

The previous example demonstrated that the type of the predicted values
explains the small differences between :epkg:`scikit-learn` and
:epkg:`onnxruntime`. However, that solution relies on a custom node and
does not work with the current ONNX specifications. Another option is
to cast the predictions into floats in the :epkg:`scikit-learn`
pipeline.

.. GENERATED FROM PYTHON SOURCE LINES 344-358

.. code-block:: default


    ctree = CastRegressor(DecisionTreeRegressor(max_depth=max_depth))
    ctree.fit(Xi_train, yi_train)

    onx6 = to_onnx(ctree, Xi_train[:1].astype(numpy.float32))
    sess6 = InferenceSession(onx6.SerializeToString())

    skl6 = ctree.predict(X32)
    ort6 = sess6.run(None, {'X': X32})[0]

    print(diff(skl6, ort6))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    (0.0, 0.0)

.. GENERATED FROM PYTHON SOURCE LINES 359-360

Success!


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 6.665 seconds)

.. _sphx_glr_download_auto_examples_plot_ebegin_float_double.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: binder-badge

    .. image:: images/binder_badge_logo.svg
      :target: https://mybinder.org/v2/gh/sdpython/onnxcustom/master?urlpath=lab/tree/notebooks/auto_examples/plot_ebegin_float_double.ipynb
      :alt: Launch binder
      :width: 150 px


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_ebegin_float_double.py `


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_ebegin_float_double.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_