.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_usparse_xgboost.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_plot_usparse_xgboost.py>`
        to download the full example code or to run this example in your
        browser via Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_usparse_xgboost.py:

.. _example-sparse-tfidf:

TfIdf and sparse matrices
=========================

.. index:: xgboost, lightgbm, sparse, ensemble

`TfidfVectorizer <https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html>`_
usually creates sparse data. If the data is sparse enough, matrices
usually stay sparse all along the pipeline until the predictor is trained.
Sparse matrices do not distinguish between null and missing values: both
are simply absent from the dataset. Because some predictors make that
distinction, this ambiguity may introduce discrepancies when the pipeline
is converted into ONNX. This example looks into several configurations.

.. contents::
    :local:

Imports, setups
+++++++++++++++

All imports. It also registers ONNX converters for :epkg:`xgboost`
and :epkg:`lightgbm`.

.. GENERATED FROM PYTHON SOURCE LINES 27-64

.. code-block:: default

    import warnings
    import numpy
    import pandas
    from tqdm import tqdm
    from sklearn.compose import ColumnTransformer
    from sklearn.datasets import load_iris
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
    from sklearn.experimental import (  # noqa
        enable_hist_gradient_boosting)  # noqa
    from sklearn.ensemble import (
        RandomForestClassifier, HistGradientBoostingClassifier)
    from xgboost import XGBClassifier
    from lightgbm import LGBMClassifier
    from skl2onnx.common.data_types import FloatTensorType, StringTensorType
    from skl2onnx import to_onnx, update_registered_converter
    from skl2onnx.sklapi import CastTransformer, ReplaceTransformer
    from skl2onnx.common.shape_calculator import (
        calculate_linear_classifier_output_shapes)
    from onnxmltools.convert.xgboost.operator_converters.XGBoost import (
        convert_xgboost)
    from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
        convert_lightgbm)
    from mlprodict.onnxrt import OnnxInference


    # registers the ONNX converters for XGBClassifier and LGBMClassifier
    update_registered_converter(
        XGBClassifier, 'XGBoostXGBClassifier',
        calculate_linear_classifier_output_shapes, convert_xgboost,
        options={'nocl': [True, False], 'zipmap': [True, False, 'columns']})
    update_registered_converter(
        LGBMClassifier, 'LightGbmLGBMClassifier',
        calculate_linear_classifier_output_shapes, convert_lightgbm,
        options={'nocl': [True, False], 'zipmap': [True, False]})

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    somewhereonnxcustom_39_std/_venv/lib/python3.9/site-packages/sklearn/experimental/enable_hist_gradient_boosting.py:16: UserWarning: Since version 1.0, it is not needed to import enable_hist_gradient_boosting anymore. HistGradientBoostingClassifier and HistGradientBoostingRegressor are now stable and can be normally imported from sklearn.ensemble.
      warnings.warn(

.. GENERATED FROM PYTHON SOURCE LINES 65-69

Artificial datasets
+++++++++++++++++++++++++++

Iris + a text column.

.. GENERATED FROM PYTHON SOURCE LINES 69-86

.. code-block:: default

    cst = ['class zero', 'class one', 'class two']

    data = load_iris()
    X = data.data[:, :2]
    y = data.target
    df = pandas.DataFrame(X)
    df["text"] = [cst[i] for i in y]

    ind = numpy.arange(X.shape[0])
    numpy.random.shuffle(ind)
    X = X[ind, :].copy()
    y = y[ind].copy()
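
As a quick sanity check (this snippet is not part of the generated script,
it reuses `df` and the imports defined above), the text branch of the
preprocessing indeed produces a sparse CSR matrix in which null
coefficients are simply absent:

.. code-block:: python

    # Not part of the generated example: verify that the TF-IDF branch
    # produces a sparse CSR matrix where null coefficients are simply absent.
    from scipy.sparse import issparse

    tfidf = Pipeline([('count', CountVectorizer()),
                      ('tfidf', TfidfTransformer())])
    tx = tfidf.fit_transform(df["text"])
    print(issparse(tx))      # True: a scipy sparse matrix
    print(tx.shape, tx.nnz)  # only the stored (non-null) entries count
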
.. GENERATED FROM PYTHON SOURCE LINES 87-93

Train ensemble after sparse
+++++++++++++++++++++++++++

The example uses the Iris dataset augmented with the artificial text
column, preprocessed with a TF-IDF. `sparse_threshold=1.` prevents the
sparse matrices from being converted into dense matrices.

.. GENERATED FROM PYTHON SOURCE LINES 93-224

.. code-block:: default


    def make_pipelines(df_train, y_train, models=None,
                       sparse_threshold=1., replace_nan=False,
                       insert_replace=False, verbose=False):

        if models is None:
            models = [
                RandomForestClassifier, HistGradientBoostingClassifier,
                XGBClassifier, LGBMClassifier]

        pipes = []
        for model in tqdm(models):

            if model == HistGradientBoostingClassifier:
                kwargs = dict(max_iter=5)
            elif model == XGBClassifier:
                kwargs = dict(n_estimators=5, use_label_encoder=False)
            else:
                kwargs = dict(n_estimators=5)

            if insert_replace:
                # variant with an explicit ReplaceTransformer after the TF-IDF
                pipe = Pipeline([
                    ('union', ColumnTransformer([
                        ('scale1', StandardScaler(), [0, 1]),
                        ('subject',
                         Pipeline([
                             ('count', CountVectorizer()),
                             ('tfidf', TfidfTransformer()),
                             ('repl', ReplaceTransformer()),
                         ]), "text"),
                    ], sparse_threshold=sparse_threshold)),
                    ('cast', CastTransformer()),
                    ('cls', model(max_depth=3, **kwargs)),
                ])
            else:
                pipe = Pipeline([
                    ('union', ColumnTransformer([
                        ('scale1', StandardScaler(), [0, 1]),
                        ('subject',
                         Pipeline([
                             ('count', CountVectorizer()),
                             ('tfidf', TfidfTransformer()),
                         ]), "text"),
                    ], sparse_threshold=sparse_threshold)),
                    ('cast', CastTransformer()),
                    ('cls', model(max_depth=3, **kwargs)),
                ])

            try:
                pipe.fit(df_train, y_train)
            except TypeError as e:
                # HistGradientBoostingClassifier rejects sparse input
                obs = dict(model=model.__name__, pipe=pipe, error=e)
                pipes.append(obs)
                continue

            options = {model: {'zipmap': False}}
            if replace_nan:
                options[TfidfTransformer] = {'nan': True}

            # convert
            with warnings.catch_warnings(record=False):
                warnings.simplefilter("ignore", (FutureWarning, UserWarning))
                model_onnx = to_onnx(
                    pipe,
                    initial_types=[('input', FloatTensorType([None, 2])),
                                   ('text', StringTensorType([None, 1]))],
                    target_opset=12, options=options)

            with open('model.onnx', 'wb') as f:
                f.write(model_onnx.SerializeToString())

            oinf = OnnxInference(model_onnx)
            inputs = {"input": df_train[[0, 1]].values.astype(numpy.float32),
                      "text": df_train[["text"]].values}
            pred_onx = oinf.run(inputs)

            diff = numpy.abs(
                pred_onx['probabilities'].ravel() -
                pipe.predict_proba(df_train).ravel()).sum()

            if verbose:
                def td(a):
                    # densify a sparse row, marking absent entries with nan
                    if hasattr(a, 'todense'):
                        b = a.todense()
                        ind = set(a.indices)
                        for i in range(b.shape[1]):
                            if i not in ind:
                                b[0, i] = numpy.nan
                        return b
                    return a

                oinf = OnnxInference(model_onnx)
                pred_onx2 = oinf.run(inputs)
                diff2 = numpy.abs(
                    pred_onx2['probabilities'].ravel() -
                    pipe.predict_proba(df_train).ravel()).sum()

            if diff > 0.1:
                # look for the rows causing the discrepancies
                for i, (l1, l2) in enumerate(zip(
                        pipe.predict_proba(df_train),
                        pred_onx['probabilities'])):
                    d = numpy.abs(l1 - l2).sum()
                    if verbose and d > 0.1:
                        print("\nDISCREPANCY DETAILS")
                        print(d, i, l1, l2)
                        pre = pipe.steps[0][-1].transform(df_train)
                        print("idf", pre[i].dtype, td(pre[i]))
                        pre2 = pipe.steps[1][-1].transform(pre)
                        print("cas", pre2[i].dtype, td(pre2[i]))
                        inter = oinf.run(inputs, intermediate=True)
                        onx = inter['tfidftr_norm']
                        print("onx", onx.dtype, onx[i])
                        onx = inter['variable3']

            obs = dict(model=model.__name__,
                       discrepancies=diff,
                       model_onnx=model_onnx, pipe=pipe)
            if verbose:
                obs['discrepancy2'] = diff2
            pipes.append(obs)

        return pipes


    data_sparse = make_pipelines(df, y)
    stat = pandas.DataFrame(data_sparse).drop(['model_onnx', 'pipe'], axis=1)
    if 'error' in stat.columns:
        print(stat.drop('error', axis=1))
    stat

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

                                model  discrepancies                                              error
    0          RandomForestClassifier       0.535242                                                NaN
    1  HistGradientBoostingClassifier            NaN  A sparse matrix was passed, but dense data is ...
    2                   XGBClassifier      19.539383                                                NaN
    3                  LGBMClassifier       0.000007                                                NaN


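The large discrepancy for XGBClassifier comes from how the predictor
interprets absent entries. Here is a minimal sketch of the effect, outside
the example's pipeline (the tiny dataset and hyperparameters below are
illustrative only):

.. code-block:: python

    # Sketch: the same numbers, given as a dense array with explicit zeros
    # or as a sparse matrix with absent entries, may not yield the same
    # predictions, because XGBoost routes missing (absent) values through a
    # default branch while an explicit 0 is compared against the threshold.
    import numpy
    from scipy.sparse import csr_matrix
    from xgboost import XGBClassifier

    X_demo = numpy.array([[0., 1.], [1., 0.], [0., 0.], [1., 1.]],
                         dtype=numpy.float32)
    y_demo = numpy.array([0, 1, 0, 1])
    clf = XGBClassifier(n_estimators=3, max_depth=2, use_label_encoder=False)
    clf.fit(X_demo, y_demo)
    p_dense = clf.predict_proba(X_demo)
    # explicit zeros are dropped when building the CSR matrix
    p_sparse = clf.predict_proba(csr_matrix(X_demo))
    print(numpy.abs(p_dense - p_sparse).max())  # can be non-zero
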
.. GENERATED FROM PYTHON SOURCE LINES 225-231

Sparse data hurts.

Dense data
++++++++++

Let's replace the sparse data with dense data by using `sparse_threshold=0.`.

.. GENERATED FROM PYTHON SOURCE LINES 231-239

.. code-block:: default


    data_dense = make_pipelines(df, y, sparse_threshold=0.)
    stat = pandas.DataFrame(data_dense).drop(['model_onnx', 'pipe'], axis=1)
    if 'error' in stat.columns:
        print(stat.drop('error', axis=1))
    stat

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

                                model  discrepancies
    0          RandomForestClassifier       0.240004
    1  HistGradientBoostingClassifier       0.000006
    2                   XGBClassifier       0.000005
    3                  LGBMClassifier       0.000007


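As a reminder of how `sparse_threshold` works, here is a small sketch
reusing `df` and the imports defined above: `ColumnTransformer` keeps the
stacked output sparse only if its overall density stays below the
threshold, so `0.` forces a dense ndarray while `1.` preserves sparsity.

.. code-block:: python

    # Sketch: sparse_threshold=0. densifies the stacked output;
    # sparse_threshold=1. keeps it sparse when a branch such as the
    # TF-IDF produces a sparse matrix with low enough density.
    union = ColumnTransformer([
        ('scale1', StandardScaler(), [0, 1]),
        ('subject', Pipeline([('count', CountVectorizer()),
                              ('tfidf', TfidfTransformer())]), 'text'),
    ], sparse_threshold=0.)
    print(type(union.fit_transform(df)))  # numpy.ndarray
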
.. GENERATED FROM PYTHON SOURCE LINES 240-242

This is much better. Let's compare how the preprocessing applies
on the data.

.. GENERATED FROM PYTHON SOURCE LINES 242-249

.. code-block:: default


    print("sparse")
    print(data_sparse[-1]['pipe'].steps[0][-1].transform(df)[:2])
    print()
    print("dense")
    print(data_dense[-1]['pipe'].steps[0][-1].transform(df)[:2])

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    sparse
      (0, 0)    -0.9006811702978088
      (0, 1)    1.019004351971607
      (0, 2)    0.4323732931220851
      (0, 5)    0.9016947018779491
      (1, 0)    -1.1430169111851105
      (1, 1)    -0.13197947932162468
      (1, 2)    0.4323732931220851
      (1, 5)    0.9016947018779491

    dense
    [[-0.90068117  1.01900435  0.43237329  0.          0.          0.9016947 ]
     [-1.14301691 -0.13197948  0.43237329  0.          0.          0.9016947 ]]

.. GENERATED FROM PYTHON SOURCE LINES 250-269

This shows that `RandomForestClassifier <https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html>`_
and `XGBClassifier <https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBClassifier>`_
do not process sparse and dense matrices the same way, as opposed to
`LGBMClassifier <https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html>`_.
And `HistGradientBoostingClassifier <https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.HistGradientBoostingClassifier.html>`_
fails on sparse inputs.

Dense data with nan
+++++++++++++++++++

Let's keep the sparse data in the scikit-learn pipeline but replace null
values by nan in the ONNX graph.

.. GENERATED FROM PYTHON SOURCE LINES 269-277

.. code-block:: default


    data_dense = make_pipelines(df, y, sparse_threshold=1., replace_nan=True)
    stat = pandas.DataFrame(data_dense).drop(['model_onnx', 'pipe'], axis=1)
    if 'error' in stat.columns:
        print(stat.drop('error', axis=1))
    stat

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

                                model  discrepancies                                              error
    0          RandomForestClassifier      53.586598                                                NaN
    1  HistGradientBoostingClassifier            NaN  A sparse matrix was passed, but dense data is ...
    2                   XGBClassifier       0.000005                                                NaN
    3                  LGBMClassifier       0.000007                                                NaN


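To picture what the `{'nan': True}` option amounts to inside the ONNX
graph, here is a numpy-only sketch, again reusing `df` and the imports
defined above. It is an illustration of the idea, not the converter's
actual code:

.. code-block:: python

    # Sketch of the effect of options[TfidfTransformer] = {'nan': True}:
    # once the TF-IDF output is dense, null coefficients become nan,
    # which XGBoost or LightGBM treat as missing values, matching the
    # sparse matrices they were trained on. RandomForestClassifier has
    # no such handling of nan, hence its large discrepancy above.
    dense_tfidf = Pipeline([
        ('count', CountVectorizer()),
        ('tfidf', TfidfTransformer())]).fit_transform(df["text"]).toarray()
    dense_tfidf[dense_tfidf == 0] = numpy.nan
    print(dense_tfidf[:2])
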
.. GENERATED FROM PYTHON SOURCE LINES 278-287

Dense, 0 replaced by nan
++++++++++++++++++++++++

Instead of using a specific option to replace null values with nan values,
a custom transformer called ReplaceTransformer is explicitly inserted into
the pipeline. A new converter is added to the list of supported models.
It is equivalent to the previous option, except it is more explicit.

.. GENERATED FROM PYTHON SOURCE LINES 287-295

.. code-block:: default


    data_dense = make_pipelines(df, y, sparse_threshold=1., replace_nan=False,
                                insert_replace=True)
    stat = pandas.DataFrame(data_dense).drop(['model_onnx', 'pipe'], axis=1)
    if 'error' in stat.columns:
        print(stat.drop('error', axis=1))
    stat

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

                                model  discrepancies                                              error
    0          RandomForestClassifier      52.319427                                                NaN
    1  HistGradientBoostingClassifier            NaN  A sparse matrix was passed, but dense data is ...
    2                   XGBClassifier       0.000005                                                NaN
    3                  LGBMClassifier       0.000007                                                NaN


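ReplaceTransformer ships with skl2onnx together with a matching ONNX
converter. As a rough mental model only, a hypothetical scikit-learn
transformer performing a similar replacement could look like the class
below; it is illustrative and not skl2onnx's actual implementation:

.. code-block:: python

    # Hypothetical stand-in for ReplaceTransformer, for illustration only:
    # densify the TF-IDF output and replace null values with nan.
    from sklearn.base import BaseEstimator, TransformerMixin


    class ReplaceZeroByNanDemo(BaseEstimator, TransformerMixin):

        def fit(self, X, y=None):
            return self

        def transform(self, X):
            # densify if needed, then mark null values as missing
            Xd = numpy.asarray(X.todense()) if hasattr(X, 'todense') else X
            Xd = Xd.astype(numpy.float32).copy()
            Xd[Xd == 0] = numpy.nan
            return Xd
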
.. GENERATED FROM PYTHON SOURCE LINES 296-302

Conclusion
++++++++++

Because :epkg:`onnxruntime` does not support sparse tensors yet, the
conversion needs to be tuned depending on the model which follows the
TfIdf preprocessing, unless dense arrays are used everywhere.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  5.337 seconds)


.. _sphx_glr_download_auto_examples_plot_usparse_xgboost.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: binder-badge

    .. image:: images/binder_badge_logo.svg
      :target: https://mybinder.org/v2/gh/sdpython/onnxcustom/master?urlpath=lab/tree/notebooks/auto_examples/plot_usparse_xgboost.ipynb
      :alt: Launch binder
      :width: 150 px


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_usparse_xgboost.py <plot_usparse_xgboost.py>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_usparse_xgboost.ipynb <plot_usparse_xgboost.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_