.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_custom_parser.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_custom_parser.py:


.. _l-custom-parser:

When a custom model is neither a classifier nor a regressor
===========================================================

*scikit-learn*'s API specifies that a regressor produces one output
and a classifier produces two outputs: the predicted labels and the
probabilities. The goal here is to add a third output which tells
whether the probability of the predicted class is above a given threshold.
This is implemented in the method *validate*.

.. contents::
    :local:

Iris and scoring
++++++++++++++++

We create a new class which trains any classifier and implements
the method *validate* mentioned above.

.. GENERATED FROM PYTHON SOURCE LINES 26-87

.. code-block:: default

    import inspect
    import numpy as np
    import skl2onnx
    import onnx
    import sklearn
    from sklearn.base import ClassifierMixin, BaseEstimator, clone
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split
    from skl2onnx import update_registered_converter
    import os
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
    import onnxruntime as rt
    from onnxconverter_common.onnx_ops import (
        apply_identity, apply_cast, apply_greater
    )
    from skl2onnx import to_onnx, get_model_alias
    from skl2onnx.proto import onnx_proto
    from skl2onnx.common._registration import get_shape_calculator
    from skl2onnx.common.data_types import FloatTensorType, Int64TensorType
    import matplotlib.pyplot as plt


    class ValidatorClassifier(BaseEstimator, ClassifierMixin):

        def __init__(self, estimator=None, threshold=0.75):
            ClassifierMixin.__init__(self)
            BaseEstimator.__init__(self)
            if estimator is None:
                estimator = LogisticRegression(solver='liblinear')
            self.estimator = estimator
            self.threshold = threshold

        def fit(self, X, y, sample_weight=None):
            # Pass sample_weight only if the wrapped estimator's fit accepts it.
            sig = inspect.signature(self.estimator.fit)
            if 'sample_weight' in sig.parameters:
                self.estimator_ = clone(self.estimator).fit(
                    X, y, sample_weight=sample_weight)
            else:
                self.estimator_ = clone(self.estimator).fit(X, y)
            return self

        def predict(self, X):
            return self.estimator_.predict(X)

        def predict_proba(self, X):
            return self.estimator_.predict_proba(X)

        def validate(self, X):
            # 1 when the highest class probability reaches the threshold, else 0.
            pred = self.predict_proba(X)
            mx = pred.max(axis=1)
            return (mx >= self.threshold) * 1


    data = load_iris()
    X, y = data.data, data.target
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    model = ValidatorClassifier()
    model.fit(X_train, y_train)

.. raw:: html

    ValidatorClassifier(estimator=LogisticRegression(solver='liblinear'))


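
The rule behind *validate* can be restated on its own. The sketch below, a minimal illustration rather than part of the example, applies it with NumPy to a small made-up probability matrix; the helper name ``validate_rule`` and the probabilities are assumptions. The largest probability of each row is compared to the threshold with ``>=``, which yields one 0/1 flag per observation.

.. code-block:: python

    import numpy as np


    def validate_rule(proba, threshold=0.75):
        """Hypothetical helper restating ValidatorClassifier.validate.

        Returns 1 for every row whose largest probability is >= threshold,
        0 otherwise, so the result has one entry per observation.
        """
        row_max = proba.max(axis=1)  # highest class probability per row
        return (row_max >= threshold).astype(np.int64)


    # Made-up probabilities for 4 observations and 3 classes.
    proba = np.array([[0.10, 0.85, 0.05],
                      [0.40, 0.35, 0.25],
                      [0.25, 0.75, 0.00],
                      [0.05, 0.05, 0.90]], dtype=np.float32)

    flags = validate_rule(proba)
    print(flags)        # [1 0 1 1]
    print(flags.shape)  # (4,): one flag per row, not per class

The third row sits exactly on the threshold and is still flagged, because the comparison is not strict.
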
.. GENERATED FROM PYTHON SOURCE LINES 88-91

Let's now measure the indicator which tells whether the probability
of a prediction is above the threshold.

.. GENERATED FROM PYTHON SOURCE LINES 91-94

.. code-block:: default

    print(model.validate(X_test))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [1 1 1 0 0 1 1 1 1 1 0 0 1 0 1 1 0 0 0 1 0 0 1 1 1 0 0 0 1 0 1 0 0 1 0 1 0 1]

.. GENERATED FROM PYTHON SOURCE LINES 95-101

Conversion to ONNX
++++++++++++++++++

The conversion fails for a new model because
the library does not know any converter associated with it.

.. GENERATED FROM PYTHON SOURCE LINES 101-108

.. code-block:: default

    try:
        to_onnx(model, X_train[:1].astype(np.float32),
                target_opset=12)
    except RuntimeError as e:
        print(e)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Unable to find a shape calculator for type ''.
    It usually means the pipeline being converted contains a
    transformer or a predictor with no corresponding converter
    implemented in sklearn-onnx. If the converted is implemented
    in another library, you need to register
    the converted so that it can be used by sklearn-onnx (function
    update_registered_converter). If the model is not yet covered
    by sklearn-onnx, you may raise an issue to
    https://github.com/onnx/sklearn-onnx/issues
    to get the converter implemented or even contribute to the
    project. If the model is a custom model, a new converter must
    be implemented. Examples can be found in the gallery.

.. GENERATED FROM PYTHON SOURCE LINES 109-115

Custom converter
++++++++++++++++

We reuse some pieces of code from :ref:`l-custom-model`.
The shape calculator defines the shape of every output
of the converted model.

.. GENERATED FROM PYTHON SOURCE LINES 115-132

.. code-block:: default

    def validator_classifier_shape_calculator(operator):

        input0 = operator.inputs[0]  # inputs in ONNX graph
        outputs = operator.outputs  # outputs in ONNX graph
        op = operator.raw_operator  # scikit-learn model (must be fitted)
        if len(outputs) != 3:
            raise RuntimeError("3 outputs expected not {}.".format(len(outputs)))

        N = input0.type.shape[0]  # number of observations
        C = op.estimator_.classes_.shape[0]  # number of classes

        outputs[0].type = Int64TensorType([N])  # label
        outputs[1].type = FloatTensorType([N, C])  # probabilities
        outputs[2].type = Int64TensorType([N])  # validation, one flag per observation

.. GENERATED FROM PYTHON SOURCE LINES 133-134

Then the converter.

.. GENERATED FROM PYTHON SOURCE LINES 134-189

.. code-block:: default

    def validator_classifier_converter(scope, operator, container):
        outputs = operator.outputs  # outputs in ONNX graph
        op = operator.raw_operator  # scikit-learn model (must be fitted)

        # We reuse the existing converter and declare it
        # as a local operator.
        model = op.estimator_
        alias = get_model_alias(type(model))
        val_op = scope.declare_local_operator(alias, model)
        val_op.inputs = operator.inputs

        # We add intermediate outputs.
        val_label = scope.declare_local_variable('val_label', Int64TensorType())
        val_prob = scope.declare_local_variable('val_prob', FloatTensorType())
        val_op.outputs.append(val_label)
        val_op.outputs.append(val_prob)

        # We adjust the output shapes of the submodel.
        shape_calc = get_shape_calculator(alias)
        shape_calc(val_op)

        # We now handle the validation.
        # From opset 18 on, ReduceMax receives the axes as an input,
        # before that as an attribute.
        val_max = scope.get_unique_variable_name('val_max')
        if container.target_opset >= 18:
            axis_name = scope.get_unique_variable_name('axis')
            container.add_initializer(
                axis_name, onnx_proto.TensorProto.INT64, [1], [1])
            container.add_node(
                'ReduceMax', [val_prob.full_name, axis_name], val_max,
                name=scope.get_unique_operator_name('ReduceMax'),
                keepdims=0)
        else:
            container.add_node(
                'ReduceMax', val_prob.full_name, val_max,
                name=scope.get_unique_operator_name('ReduceMax'),
                axes=[1], keepdims=0)

        # Greater is a strict comparison: val_max > threshold.
        th_name = scope.get_unique_variable_name('threshold')
        container.add_initializer(
            th_name, onnx_proto.TensorProto.FLOAT, [1], [op.threshold])
        val_bin = scope.get_unique_variable_name('val_bin')
        apply_greater(scope, [val_max, th_name], val_bin, container)

        val_val = scope.get_unique_variable_name('validate')
        apply_cast(scope, val_bin, val_val, container,
                   to=onnx_proto.TensorProto.INT64)

        # Finally, we connect the intermediate results to the outputs
        # of this operator.
        apply_identity(scope, val_label.full_name, outputs[0].full_name, container)
        apply_identity(scope, val_prob.full_name, outputs[1].full_name, container)
        apply_identity(scope, val_val, outputs[2].full_name, container)

.. GENERATED FROM PYTHON SOURCE LINES 190-191

Then the registration.

.. GENERATED FROM PYTHON SOURCE LINES 191-197

.. code-block:: default

    update_registered_converter(ValidatorClassifier, 'CustomValidatorClassifier',
                                validator_classifier_shape_calculator,
                                validator_classifier_converter)

.. GENERATED FROM PYTHON SOURCE LINES 198-199

And the conversion...

.. GENERATED FROM PYTHON SOURCE LINES 199-206

.. code-block:: default

    try:
        to_onnx(model, X_test[:1].astype(np.float32),
                target_opset=12)
    except RuntimeError as e:
        print(e)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    3 outputs expected not 2.

.. GENERATED FROM PYTHON SOURCE LINES 207-214

It fails because the library treats the model as a classifier and
therefore declares only two outputs (labels and probabilities),
whereas the shape calculator above expects three.
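
The error is about the number of declared outputs, not about the validation subgraph itself. That subgraph can still be sanity-checked on its own. The sketch below mimics the ``ReduceMax``, ``Greater`` and ``Cast`` nodes added by ``validator_classifier_converter`` with NumPy, assuming those ONNX operators behave like ``max``, ``np.greater`` and ``astype``; it does not run ONNX Runtime, and the helper names and probabilities are made up for illustration. It shows two things: the validation output holds one value per observation (N, not C), and ``Greater`` is strict while *validate* uses ``>=``, so a row whose highest probability equals the threshold exactly is flagged differently.

.. code-block:: python

    import numpy as np


    def onnx_validation_branch(proba, threshold):
        """Hypothetical NumPy stand-in for the ReduceMax -> Greater -> Cast
        nodes of the converter (an approximation, not ONNX Runtime)."""
        # ReduceMax over axis 1 with keepdims=0: shape (N, C) -> (N,)
        val_max = proba.max(axis=1)
        # The threshold initializer has shape [1] and broadcasts against (N,).
        th = np.array([threshold], dtype=np.float32)
        # Greater is a strict comparison.
        val_bin = np.greater(val_max, th)
        # Cast to int64, like apply_cast(..., to=INT64).
        return val_bin.astype(np.int64)


    def python_validate(proba, threshold):
        # Same rule as ValidatorClassifier.validate (non-strict >=).
        return (proba.max(axis=1) >= threshold).astype(np.int64)


    # Made-up probabilities: 4 observations (N) and 3 classes (C).
    proba = np.array([[0.10, 0.85, 0.05],
                      [0.40, 0.35, 0.25],
                      [0.25, 0.75, 0.00],
                      [0.05, 0.05, 0.90]], dtype=np.float32)

    onnx_like = onnx_validation_branch(proba, 0.75)
    python_like = python_validate(proba, 0.75)
    print(onnx_like.shape)  # (4,): N entries, not C
    print(onnx_like)        # [1 0 0 1]
    print(python_like)      # [1 0 1 1]

With real float32 probabilities an exact tie with the threshold is unlikely, which is consistent with the identical validation flags obtained in the final test below. If the tie case matters, the ONNX ``GreaterOrEqual`` operator (available from opset 12) matches the ``>=`` used by *validate*.
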

We need to add a custom parser to tell the library
this model produces three outputs.

Custom parser
+++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 214-234

.. code-block:: default

    def validator_classifier_parser(scope, model, inputs, custom_parsers=None):
        alias = get_model_alias(type(model))
        this_operator = scope.declare_local_operator(alias, model)

        # inputs
        this_operator.inputs.append(inputs[0])

        # outputs: label, probabilities and validation flag
        val_label = scope.declare_local_variable('val_label', Int64TensorType())
        val_prob = scope.declare_local_variable('val_prob', FloatTensorType())
        val_val = scope.declare_local_variable('val_val', Int64TensorType())
        this_operator.outputs.append(val_label)
        this_operator.outputs.append(val_prob)
        this_operator.outputs.append(val_val)

        # end
        return this_operator.outputs

.. GENERATED FROM PYTHON SOURCE LINES 235-236

Registration.

.. GENERATED FROM PYTHON SOURCE LINES 236-243

.. code-block:: default

    update_registered_converter(ValidatorClassifier, 'CustomValidatorClassifier',
                                validator_classifier_shape_calculator,
                                validator_classifier_converter,
                                parser=validator_classifier_parser)

.. GENERATED FROM PYTHON SOURCE LINES 244-245

And the conversion again.

.. GENERATED FROM PYTHON SOURCE LINES 245-249

.. code-block:: default

    model_onnx = to_onnx(model, X_test[:1].astype(np.float32),
                         target_opset=12)

.. GENERATED FROM PYTHON SOURCE LINES 250-254

Final test
++++++++++

We now need to check that ONNX produces the same results.

.. GENERATED FROM PYTHON SOURCE LINES 254-270

.. code-block:: default

    X32 = X_test[:5].astype(np.float32)

    sess = rt.InferenceSession(model_onnx.SerializeToString(),
                               providers=["CPUExecutionProvider"])
    results = sess.run(None, {'X': X32})

    print("--labels--")
    print("sklearn", model.predict(X32))
    print("onnx", results[0])
    print("--probabilities--")
    print("sklearn", model.predict_proba(X32))
    print("onnx", results[1])

    print("--validation--")
    print("sklearn", model.validate(X32))
    print("onnx", results[2])


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    --labels--
    sklearn [2 0 1 1 1]
    onnx [2 0 1 1 1]
    --probabilities--
    sklearn [[7.07999425e-04 2.48572306e-01 7.50719694e-01]
     [8.21007140e-01 1.78895436e-01 9.74239685e-05]
     [2.10535503e-02 7.60262548e-01 2.18683902e-01]
     [1.07188246e-02 6.36087208e-01 3.53193967e-01]
     [8.79788000e-03 6.77701629e-01 3.13500491e-01]]
    onnx [[7.07984611e-04 2.48572275e-01 7.50719726e-01]
     [8.21007073e-01 1.78895399e-01 9.74424402e-05]
     [2.10535377e-02 7.60262549e-01 2.18683958e-01]
     [1.07187675e-02 6.36087179e-01 3.53194058e-01]
     [8.79786070e-03 6.77701533e-01 3.13500643e-01]]
    --validation--
    sklearn [1 1 1 0 0]
    onnx [1 1 1 0 0]

.. GENERATED FROM PYTHON SOURCE LINES 271-275

It looks good: the labels and the validation flags are identical, and the
probabilities differ by less than 1e-6, a gap expected since the converted
model computes in float32.

Display the ONNX graph
++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 275-289

.. code-block:: default

    pydot_graph = GetPydotGraph(
        model_onnx.graph, name=model_onnx.graph.name, rankdir="TB",
        node_producer=GetOpNodeProducer(
            "docstring", color="yellow", fillcolor="yellow", style="filled"))
    pydot_graph.write_dot("validator_classifier.dot")

    os.system('dot -O -Gdpi=300 -Tpng validator_classifier.dot')

    image = plt.imread("validator_classifier.dot.png")
    fig, ax = plt.subplots(figsize=(40, 20))
    ax.imshow(image)
    ax.axis('off')


.. image-sg:: /auto_examples/images/sphx_glr_plot_custom_parser_001.png
   :alt: plot custom parser
   :srcset: /auto_examples/images/sphx_glr_plot_custom_parser_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    (-0.5, 2495.5, 4934.5, -0.5)

.. GENERATED FROM PYTHON SOURCE LINES 290-291

**Versions used for this example**

.. GENERATED FROM PYTHON SOURCE LINES 291-297

.. code-block:: default

    print("numpy:", np.__version__)
    print("scikit-learn:", sklearn.__version__)
    print("onnx: ", onnx.__version__)
    print("onnxruntime: ", rt.__version__)
    print("skl2onnx: ", skl2onnx.__version__)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    numpy: 1.23.5
    scikit-learn: 1.2.2
    onnx: 1.13.1
    onnxruntime: 1.14.1
    skl2onnx: 1.14.0


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 9.047 seconds)


.. _sphx_glr_download_auto_examples_plot_custom_parser.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_custom_parser.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_custom_parser.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_