.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_tutorial/plot_gexternal_xgboost.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_tutorial_plot_gexternal_xgboost.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_tutorial_plot_gexternal_xgboost.py:

.. _example-xgboost:

Convert a pipeline with an XGBoost model
========================================

.. index:: XGBoost

:epkg:`sklearn-onnx` only converts :epkg:`scikit-learn` models into
:epkg:`ONNX`, but many libraries implement the :epkg:`scikit-learn` API
so that their models can be included in a :epkg:`scikit-learn` pipeline.
This example considers a pipeline that includes an :epkg:`XGBoost` model.
:epkg:`sklearn-onnx` can convert the whole pipeline as long as it
knows the converter associated with *XGBClassifier*.
Let's see how to do it.

.. contents::
    :local:

Train an XGBoost classifier
+++++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 26-75

.. code-block:: default

    from pyquickhelper.helpgen.graphviz_helper import plot_graphviz
    from mlprodict.onnxrt import OnnxInference
    import numpy
    import onnxruntime as rt
    from sklearn.datasets import load_iris, load_diabetes, make_classification
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from xgboost import XGBClassifier, XGBRegressor, DMatrix, train as train_xgb
    from skl2onnx.common.data_types import FloatTensorType
    from skl2onnx import convert_sklearn, to_onnx, update_registered_converter
    from skl2onnx.common.shape_calculator import (
        calculate_linear_classifier_output_shapes,
        calculate_linear_regressor_output_shapes)
    from onnxmltools.convert.xgboost.operator_converters.XGBoost import (
        convert_xgboost)
    from onnxmltools.convert import convert_xgboost as convert_xgboost_booster


    data = load_iris()
    X = data.data[:, :2]
    y = data.target

    ind = numpy.arange(X.shape[0])
    numpy.random.shuffle(ind)
    X = X[ind, :].copy()
    y = y[ind].copy()

    pipe = Pipeline([('scaler', StandardScaler()),
                     ('xgb', XGBClassifier(n_estimators=3))])
    pipe.fit(X, y)

    # The conversion fails, but this is expected.

    try:
        convert_sklearn(pipe, 'pipeline_xgboost',
                        [('input', FloatTensorType([None, 2]))],
                        target_opset={'': 12, 'ai.onnx.ml': 2})
    except Exception as e:
        print(e)

    # The error message says that no converter was found
    # for XGBoost models. By default, sklearn-onnx
    # only handles models from scikit-learn, but it can
    # be extended to any model following the scikit-learn
    # API, as long as it knows a converter
    # for every model used in the pipeline. That is why
    # we need to register a converter.

.. GENERATED FROM PYTHON SOURCE LINES 76-87

Register the converter for XGBClassifier
++++++++++++++++++++++++++++++++++++++++

The converter is implemented in :epkg:`onnxmltools`:
`onnxmltools...XGBoost.py `_,
and the shape calculator in
`onnxmltools...Classifier.py `_.

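
The idea of registering a converter can be pictured with a toy registry.
The sketch below is only an illustration and not how :epkg:`sklearn-onnx`
is implemented: ``register``, ``convert_steps``, ``UnregisteredModelError``,
``ToyScaler`` and ``ToyBooster`` are hypothetical names, and the "converters"
return plain strings instead of ONNX nodes. It shows why a pipeline cannot
be converted until every step has a shape calculator and a converter
attached to its type.

.. code-block:: python

    class UnregisteredModelError(RuntimeError):
        """Raised when a model type has no converter (hypothetical name)."""


    # Hypothetical registry: model class -> (alias, shape function, converter).
    _REGISTRY = {}


    def register(model_type, alias, shape_fct, convert_fct):
        """Record the two callables needed to translate `model_type`."""
        _REGISTRY[model_type] = (alias, shape_fct, convert_fct)


    def convert_steps(steps):
        """Convert every step in order, failing on the first unknown type."""
        converted = []
        for name, model in steps:
            try:
                alias, shape_fct, convert_fct = _REGISTRY[type(model)]
            except KeyError:
                raise UnregisteredModelError(
                    "no converter registered for %s"
                    % type(model).__name__) from None
            converted.append(
                (name, alias, shape_fct(model), convert_fct(model)))
        return converted


    class ToyScaler:
        """Stands in for a model the library already knows."""


    class ToyBooster:
        """Stands in for a model coming from another library."""


    register(ToyScaler, 'ToyScaler',
             lambda m: 'same shape as input', lambda m: 'Scaler node')

    steps = [('scaler', ToyScaler()), ('xgb', ToyBooster())]

    # First attempt: the second step is unknown, the conversion fails.
    try:
        convert_steps(steps)
        failed = False
    except UnregisteredModelError as exc:
        failed = True
        message = str(exc)

    # After registration, the same call succeeds.
    register(ToyBooster, 'ToyBooster',
             lambda m: 'labels and probabilities',
             lambda m: 'TreeEnsemble node')
    result = convert_steps(steps)

The real registration call takes the same kind of information: the model
class, an alias, a shape calculator and a converter.
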

.. GENERATED FROM PYTHON SOURCE LINES 87-93

.. code-block:: default

    update_registered_converter(
        XGBClassifier, 'XGBoostXGBClassifier',
        calculate_linear_classifier_output_shapes, convert_xgboost,
        options={'nocl': [True, False], 'zipmap': [True, False, 'columns']})

.. GENERATED FROM PYTHON SOURCE LINES 94-96

Convert again
+++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 96-106

.. code-block:: default

    model_onnx = convert_sklearn(
        pipe, 'pipeline_xgboost',
        [('input', FloatTensorType([None, 2]))],
        target_opset={'': 12, 'ai.onnx.ml': 2})

    # And save.
    with open("pipeline_xgboost.onnx", "wb") as f:
        f.write(model_onnx.SerializeToString())

.. GENERATED FROM PYTHON SOURCE LINES 107-111

Compare the predictions
+++++++++++++++++++++++

Predictions with XGBoost.

.. GENERATED FROM PYTHON SOURCE LINES 111-115

.. code-block:: default

    print("predict", pipe.predict(X[:5]))
    print("predict_proba", pipe.predict_proba(X[:1]))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    predict [1 0 2 1 1]
    predict_proba [[0.17186314 0.5699681 0.25816876]]

.. GENERATED FROM PYTHON SOURCE LINES 116-117

Predictions with onnxruntime.

.. GENERATED FROM PYTHON SOURCE LINES 117-123

.. code-block:: default

    sess = rt.InferenceSession("pipeline_xgboost.onnx")
    pred_onx = sess.run(None, {"input": X[:5].astype(numpy.float32)})
    print("predict", pred_onx[0])
    print("predict_proba", pred_onx[1][:1])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    predict [1 0 2 1 1]
    predict_proba [{0: 0.1718631386756897, 1: 0.5699681043624878, 2: 0.2581687569618225}]

.. GENERATED FROM PYTHON SOURCE LINES 124-126

Final graph
+++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 126-134

.. code-block:: default

    oinf = OnnxInference(model_onnx)
    ax = plot_graphviz(oinf.to_dot())
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)


.. image-sg:: /auto_tutorial/images/sphx_glr_plot_gexternal_xgboost_001.png
   :alt: plot gexternal xgboost
   :srcset: /auto_tutorial/images/sphx_glr_plot_gexternal_xgboost_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 135-137

Same example with XGBRegressor
++++++++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 137-154

.. code-block:: default

    update_registered_converter(
        XGBRegressor, 'XGBoostXGBRegressor',
        calculate_linear_regressor_output_shapes, convert_xgboost)


    data = load_diabetes()
    x = data.data
    y = data.target
    X_train, X_test, y_train, _ = train_test_split(x, y, test_size=0.5)

    pipe = Pipeline([('scaler', StandardScaler()),
                     ('xgb', XGBRegressor(n_estimators=3))])
    pipe.fit(X_train, y_train)

    print("predict", pipe.predict(X_test[:5]))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    predict [147.65546 84.55453 37.157578 103.809296 105.27902 ]

.. GENERATED FROM PYTHON SOURCE LINES 155-156

ONNX

.. GENERATED FROM PYTHON SOURCE LINES 156-164

.. code-block:: default

    onx = to_onnx(pipe, X_train.astype(numpy.float32),
                  target_opset={'': 12, 'ai.onnx.ml': 2})

    sess = rt.InferenceSession(onx.SerializeToString())
    pred_onx = sess.run(None, {"X": X_test[:5].astype(numpy.float32)})
    print("predict", pred_onx[0].ravel())

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    predict [147.65546 84.55453 37.157578 103.809296 105.27902 ]

.. GENERATED FROM PYTHON SOURCE LINES 165-167

Some discrepancies may appear. In that case,
read :ref:`l-example-discrepencies-float-double`.

.. GENERATED FROM PYTHON SOURCE LINES 169-175

Same with a Booster
+++++++++++++++++++

A booster cannot be inserted in a pipeline. It requires
a different conversion function because it does not
follow the :epkg:`scikit-learn` API.

.. GENERATED FROM PYTHON SOURCE LINES 175-204

.. code-block:: default

    x, y = make_classification(n_classes=2, n_features=5,
                               n_samples=100,
                               random_state=42, n_informative=3)
    X_train, X_test, y_train, _ = train_test_split(x, y, test_size=0.5,
                                                   random_state=42)

    dtrain = DMatrix(X_train, label=y_train)

    param = {'objective': 'multi:softmax', 'num_class': 3}
    bst = train_xgb(param, dtrain, 10)

    initial_type = [('float_input', FloatTensorType([None, X_train.shape[1]]))]

    try:
        onx = convert_xgboost_booster(bst, "name", initial_types=initial_type,
                                      target_opset=12)
        cont = True
    except AssertionError as e:
        print("XGBoost is too recent or onnxmltools too old.", e)
        cont = False

    if cont:
        sess = rt.InferenceSession(onx.SerializeToString())
        input_name = sess.get_inputs()[0].name
        label_name = sess.get_outputs()[0].name
        pred_onx = sess.run(
            [label_name], {input_name: X_test.astype(numpy.float32)})[0]
        print(pred_onx)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [0 0 1 1 0 1 0 1 0 1 0 0 1 1 1 0 0 1 1 1 1 0 0 1 0 0 0 1 1 1 0 1 1 0 1 1 1 0 1 1 1 0 0 1 1 0 0 0 1 0]

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 9.343 seconds)

.. _sphx_glr_download_auto_tutorial_plot_gexternal_xgboost.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_gexternal_xgboost.py <plot_gexternal_xgboost.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_gexternal_xgboost.ipynb <plot_gexternal_xgboost.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_