.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_tfidfvectorizer.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plot_tfidfvectorizer.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_tfidfvectorizer.py:


.. _l-example-tfidfvectorizer:

TfIdfVectorizer with ONNX
=========================

This example is inspired by the following example:
`Column Transformer with Heterogeneous Data Sources `_,
which builds a pipeline to classify text.

.. contents::
    :local:

Train a pipeline with TfidfVectorizer
+++++++++++++++++++++++++++++++++++++

It replicates the pipeline from the *scikit-learn* documentation
but reduces it to the part ONNX actually supports without implementing
a custom converter. Let's get the data.

.. GENERATED FROM PYTHON SOURCE LINES 26-64

.. code-block:: default


    import matplotlib.pyplot as plt
    import os
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
    import numpy
    import onnxruntime as rt
    from skl2onnx.common.data_types import StringTensorType
    from skl2onnx import convert_sklearn
    import numpy as np
    from sklearn.base import BaseEstimator, TransformerMixin
    from sklearn.datasets import fetch_20newsgroups
    try:
        from sklearn.datasets._twenty_newsgroups import (
            strip_newsgroup_footer, strip_newsgroup_quoting)
    except ImportError:
        # scikit-learn < 0.24
        from sklearn.datasets.twenty_newsgroups import (
            strip_newsgroup_footer, strip_newsgroup_quoting)
    from sklearn.decomposition import TruncatedSVD
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.pipeline import Pipeline
    from sklearn.compose import ColumnTransformer
    from sklearn.metrics import classification_report
    from sklearn.linear_model import LogisticRegression

    # limit the list of categories to make running this example faster.
    categories = ['alt.atheism', 'talk.religion.misc']
    train = fetch_20newsgroups(random_state=1,
                               subset='train',
                               categories=categories,
                               )
    test = fetch_20newsgroups(random_state=1,
                              subset='test',
                              categories=categories,
                              )


.. GENERATED FROM PYTHON SOURCE LINES 65-68

The first transform extracts two fields from the data. We take it out
of the pipeline and assume the data is defined by two text columns.

.. GENERATED FROM PYTHON SOURCE LINES 68-103

.. code-block:: default



    class SubjectBodyExtractor(BaseEstimator, TransformerMixin):
        """Extract the subject & body from a usenet post in a single pass.

        Takes a sequence of strings and produces an array with two columns:
        the subject and the body of each post.
        """

        def fit(self, x, y=None):
            return self

        def transform(self, posts):
            # construct object dtype array with two columns
            # first column = 'subject' and second column = 'body'
            features = np.empty(shape=(len(posts), 2), dtype=object)
            for i, text in enumerate(posts):
                headers, _, bod = text.partition('\n\n')
                bod = strip_newsgroup_footer(bod)
                bod = strip_newsgroup_quoting(bod)
                features[i, 1] = bod

                prefix = 'Subject:'
                sub = ''
                for line in headers.split('\n'):
                    if line.startswith(prefix):
                        sub = line[len(prefix):]
                        break
                features[i, 0] = sub

            return features


    train_data = SubjectBodyExtractor().fit_transform(train.data)
    test_data = SubjectBodyExtractor().fit_transform(test.data)
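As a quick sanity check (not part of the original script), the extractor's
output can be inspected directly; it is expected to be a ``(n_samples, 2)``
object array with the subject in the first column and the stripped body in
the second:

.. code-block:: default


    # illustration only: look at what SubjectBodyExtractor returns
    print(train_data.shape)           # (n_samples, 2)
    print(repr(train_data[0, 0]))     # subject line of the first post
    print(train_data[0, 1][:100])     # beginning of the stripped body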
.. GENERATED FROM PYTHON SOURCE LINES 104-106

The pipeline is almost the same except we remove the custom features.

.. GENERATED FROM PYTHON SOURCE LINES 106-140

.. code-block:: default


    pipeline = Pipeline([
        ('union', ColumnTransformer(
            [
                ('subject', TfidfVectorizer(min_df=50, max_features=500), 0),

                ('body_bow', Pipeline([
                    ('tfidf', TfidfVectorizer()),
                    ('best', TruncatedSVD(n_components=50)),
                ]), 1),

                # Removed from the original example as
                # it requires a custom converter.
                # ('body_stats', Pipeline([
                #     ('stats', TextStats()),  # returns a list of dicts
                #     ('vect', DictVectorizer()),  # list of dicts -> feature matrix
                # ]), 1),
            ],
            transformer_weights={
                'subject': 0.8,
                'body_bow': 0.5,
                # 'body_stats': 1.0,
            }
        )),

        # Use a LogisticRegression classifier on the combined features
        # instead of LinearSVC (not fully supported by onnxruntime).
        ('logreg', LogisticRegression()),
    ])

    pipeline.fit(train_data, train.target)
    print(classification_report(pipeline.predict(test_data), test.target))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

                  precision    recall  f1-score   support

               0       0.69      0.78      0.73       285
               1       0.75      0.66      0.70       285

        accuracy                           0.72       570
       macro avg       0.72      0.72      0.71       570
    weighted avg       0.72      0.72      0.71       570


.. GENERATED FROM PYTHON SOURCE LINES 141-150

ONNX conversion
+++++++++++++++

It is difficult to replicate exactly the same tokenizer behaviour when the
tokenizer comes from spacy, gensim or nltk. The default tokenizer used by
*scikit-learn* relies on regular expressions, and its exact replication is
still being implemented. The current implementation only considers a list
of separators, which is defined in the variable *seps*.

.. GENERATED FROM PYTHON SOURCE LINES 150-166

.. code-block:: default


    seps = {
        TfidfVectorizer: {
            "separators": [
                ' ', '.', '\\?', ',', ';', ':', '!',
                '\\(', '\\)', '\n', '"', "'", "-",
                "\\[", "\\]", "@"
            ]
        }
    }
    model_onnx = convert_sklearn(
        pipeline, "tfidf",
        initial_types=[("input", StringTensorType([None, 2]))],
        options=seps, target_opset=12)


.. GENERATED FROM PYTHON SOURCE LINES 167-168

And save.

.. GENERATED FROM PYTHON SOURCE LINES 168-171

.. code-block:: default


    with open("pipeline_tfidf.onnx", "wb") as f:
        f.write(model_onnx.SerializeToString())


.. GENERATED FROM PYTHON SOURCE LINES 172-173

Predictions with onnxruntime.

.. GENERATED FROM PYTHON SOURCE LINES 173-181

.. code-block:: default


    sess = rt.InferenceSession("pipeline_tfidf.onnx")
    print('---', train_data[0])

    inputs = {'input': train_data[:1]}
    pred_onx = sess.run(None, inputs)
    print("predict", pred_onx[0])
    print("predict_proba", pred_onx[1])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    --- [" Re: Jews can't hide from keith@cco."
     'Deletions...\n\nSo, you consider the german poster\'s remark anti-semitic? Perhaps you\nimply that anyone in Germany who doesn\'t agree with israely policy in a\nnazi? Pray tell, how does it even qualify as "casual anti-semitism"? \nIf the term doesn\'t apply, why then bring it up?\n\nYour own bigotry is shining through. \n-- ']
    predict [1]
    predict_proba [{0: 0.4390333592891693, 1: 0.5609666705131531}]


.. GENERATED FROM PYTHON SOURCE LINES 182-183

With *scikit-learn*:

.. GENERATED FROM PYTHON SOURCE LINES 183-186

.. code-block:: default


    print(pipeline.predict(train_data[:1]))
    print(pipeline.predict_proba(train_data[:1]))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [0]
    [[0.7180778 0.2819222]]


.. GENERATED FROM PYTHON SOURCE LINES 187-190

There are discrepancies for this model because the tokenization
is not exactly the same. This is a work in progress.
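To make the mismatch concrete, the tokens produced by the *scikit-learn*
analyzer can be compared with a plain split on the separators listed in
*seps*. This is only a rough sketch of the ONNX side, not the exact
behaviour of the ``Tokenizer`` operator:

.. code-block:: default


    import re

    sentence = "Re: Jews can't hide from keith@cco."

    # tokens produced by the default scikit-learn analyzer
    # (lowercasing + the token pattern r"(?u)\b\w\w+\b")
    sk_tokens = TfidfVectorizer().build_analyzer()(sentence)

    # rough approximation of the converted model: split on the separators
    # above (the '\\' escapes are dropped because re.escape is applied next)
    separators = [s.replace("\\", "")
                  for s in seps[TfidfVectorizer]["separators"]]
    pattern = "|".join(re.escape(s) for s in separators)
    onnx_like_tokens = [t.lower() for t in re.split(pattern, sentence) if t]

    print("scikit-learn:", sk_tokens)
    print("onnx-like   :", onnx_like_tokens)

The naive split keeps single-character tokens such as ``t`` from ``can't``,
while the scikit-learn analyzer drops tokens shorter than two characters,
which is one source of the differences observed above.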
.. GENERATED FROM PYTHON SOURCE LINES 192-196

Display the ONNX graph
++++++++++++++++++++++

Finally, let's see the graph converted with *sklearn-onnx*.

.. GENERATED FROM PYTHON SOURCE LINES 196-211

.. code-block:: default


    pydot_graph = GetPydotGraph(
        model_onnx.graph, name=model_onnx.graph.name, rankdir="TB",
        node_producer=GetOpNodeProducer("docstring", color="yellow",
                                        fillcolor="yellow", style="filled"))
    pydot_graph.write_dot("pipeline_tfidf.dot")

    os.system('dot -O -Gdpi=300 -Tpng pipeline_tfidf.dot')

    image = plt.imread("pipeline_tfidf.dot.png")
    fig, ax = plt.subplots(figsize=(40, 20))
    ax.imshow(image)
    ax.axis('off')




.. image-sg:: /auto_examples/images/sphx_glr_plot_tfidfvectorizer_001.png
   :alt: plot tfidfvectorizer
   :srcset: /auto_examples/images/sphx_glr_plot_tfidfvectorizer_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (-0.5, 3887.5, 11475.5, -0.5)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 39.063 seconds)


.. _sphx_glr_download_auto_examples_plot_tfidfvectorizer.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_tfidfvectorizer.py <plot_tfidfvectorizer.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_tfidfvectorizer.ipynb <plot_tfidfvectorizer.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_