.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gyexamples/plot_orttraining_linear_regression.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_gyexamples_plot_orttraining_linear_regression.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gyexamples_plot_orttraining_linear_regression.py:

.. _l-orttraining-linreg:

Train a linear regression with onnxruntime-training
===================================================

This example explores how :epkg:`onnxruntime-training` can be used to
train a simple linear regression with gradient descent.
It compares the results with those obtained by
:class:`sklearn.linear_model.SGDRegressor`.

.. contents::
    :local:

A simple linear regression with scikit-learn
++++++++++++++++++++++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 20-45

.. code-block:: default

    from pprint import pprint
    import numpy
    import onnx
    from pandas import DataFrame
    from onnxruntime import (
        InferenceSession, get_device)
    from sklearn.datasets import make_regression
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import SGDRegressor
    from sklearn.neural_network import MLPRegressor
    from mlprodict.onnx_conv import to_onnx
    from onnxcustom.plotting.plotting_onnx import plot_onnxs
    from onnxcustom.utils.orttraining_helper import (
        add_loss_output, get_train_initializer)
    from onnxcustom.training.optimizers import OrtGradientOptimizer

    # synthetic regression problem with two features, cast to float32
    X, y = make_regression(n_features=2, bias=2)
    X = X.astype(numpy.float32)
    y = y.astype(numpy.float32)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    # reference model trained with scikit-learn
    lr = SGDRegressor(l1_ratio=0, max_iter=200, eta0=5e-2)
    lr.fit(X, y)
    print(lr.predict(X[:5]))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [ 53.83746772 6.79334086 86.31569796 73.72210063 -79.5078181 ]

.. GENERATED FROM PYTHON SOURCE LINES 46-47

The trained coefficients are:

.. GENERATED FROM PYTHON SOURCE LINES 47-49

.. code-block:: default

    print("trained coefficients:", lr.coef_, lr.intercept_)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    trained coefficients: [67.74134798 7.91162025] [2.00017697]

.. GENERATED FROM PYTHON SOURCE LINES 50-52

However, this model does not expose its training curve.
We switch to a :class:`sklearn.neural_network.MLPRegressor`
with no hidden layer and an identity activation,
which makes it a linear regression.

.. GENERATED FROM PYTHON SOURCE LINES 52-62

.. code-block:: default

    lr = MLPRegressor(hidden_layer_sizes=tuple(), activation='identity',
                      max_iter=200, batch_size=10, solver='sgd',
                      alpha=0, learning_rate_init=1e-2,
                      n_iter_no_change=200,
                      momentum=0, nesterovs_momentum=False)
    lr.fit(X, y)
    print(lr.predict(X[:5]))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    somewhere/workspace/onnxcustom/onnxcustom_UT_39_std/_venv/lib/python3.9/site-packages/sklearn/neural_network/_multilayer_perceptron.py:679: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.
      warnings.warn(
    [ 53.842594 6.7937374 86.3236 73.72894 -79.51577 ]

.. GENERATED FROM PYTHON SOURCE LINES 63-64

The trained coefficients are:

.. GENERATED FROM PYTHON SOURCE LINES 64-66

.. code-block:: default

    print("trained coefficients:", lr.coefs_, lr.intercepts_)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    trained coefficients: [array([[67.74782 ], [ 7.911858]], dtype=float32)] [array([2.0000002], dtype=float32)]

.. GENERATED FROM PYTHON SOURCE LINES 67-73

ONNX graph
++++++++++

Training with :epkg:`onnxruntime-training` starts with an ONNX
graph which defines the model to learn. It is obtained by simply
converting the previous linear regression into ONNX.

.. GENERATED FROM PYTHON SOURCE LINES 73-77

.. code-block:: default

    onx = to_onnx(lr, X_train[:1].astype(numpy.float32), target_opset=15,
                  black_op={'LinearRegressor'})

.. GENERATED FROM PYTHON SOURCE LINES 78-87

Choosing a loss
+++++++++++++++

The training requires a loss function.
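
As a rough illustration of what a loss function computes, the sketch
below writes a squared error, an absolute error and an L2 penalty with
NumPy. It is independent of onnxcustom: the helper names
(``squared_loss``, ``absolute_loss``, ``l2_penalty``) are hypothetical,
and summing over the rows is an assumption made for this illustration,
not a description of the nodes added to the ONNX graph.

.. code-block:: python

    import numpy


    def squared_loss(expected, predicted):
        # sum of squared residuals over all rows (assumed reduction)
        diff = predicted - expected
        return float((diff ** 2).sum())


    def absolute_loss(expected, predicted):
        # sum of absolute residuals over all rows (assumed reduction)
        return float(numpy.abs(predicted - expected).sum())


    def l2_penalty(weights, alpha):
        # alpha times the sum of squared coefficients of every weight array
        return float(alpha * sum((w ** 2).sum() for w in weights))


    expected = numpy.array([[1.0], [2.0], [3.0]], dtype=numpy.float32)
    predicted = numpy.array([[1.5], [2.0], [2.0]], dtype=numpy.float32)
    coef = numpy.array([[3.0], [4.0]], dtype=numpy.float32)

    sq = squared_loss(expected, predicted)
    ab = absolute_loss(expected, predicted)
    reg = sq + l2_penalty([coef], alpha=0.1)
    print(f"squared={sq!r} absolute={ab!r} squared+l2={reg!r}")

The squared error penalizes large residuals more heavily than the
absolute error, which is less sensitive to outliers.
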
By default, the loss is the squared error, but it could be the absolute
error or include regularization. Function
:func:`add_loss_output <onnxcustom.utils.orttraining_helper.add_loss_output>`
appends the loss function to the ONNX graph.

.. GENERATED FROM PYTHON SOURCE LINES 87-94

.. code-block:: default

    onx_train = add_loss_output(onx)

    plot_onnxs(onx, onx_train,
               title=['Linear Regression',
                      'Linear Regression + Loss with ONNX'])

.. image-sg:: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_001.png
   :alt: Linear Regression, Linear Regression + Loss with ONNX
   :srcset: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    array([, ], dtype=object)

.. GENERATED FROM PYTHON SOURCE LINES 95-96

Let's check that inference works.

.. GENERATED FROM PYTHON SOURCE LINES 96-102

.. code-block:: default

    sess = InferenceSession(onx_train.SerializeToString(),
                            providers=['CPUExecutionProvider'])
    # the graph with the loss takes the features and the expected labels
    res = sess.run(None, {'X': X_test, 'label': y_test.reshape((-1, 1))})
    print(f"onnx loss={res[0][0, 0] / X_test.shape[0]!r}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    onnx loss=2.4012649646465434e-08

.. GENERATED FROM PYTHON SOURCE LINES 103-111

Weights
+++++++

Every initializer is a set of weights that can be trained,
and a gradient is computed for each of them.
However, an initializer used to modify a shape or to
extract part of a tensor does not need training.
Let's remove those from the list of initializers to train.

.. GENERATED FROM PYTHON SOURCE LINES 111-116

.. code-block:: default

    inits = get_train_initializer(onx)
    # keep every initializer except the one holding a shape
    weights = {k: v for k, v in inits.items() if k != "shape_tensor"}
    pprint(list((k, v[0].shape) for k, v in weights.items()))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [('coefficient', (2, 1)), ('intercepts', (1, 1))]

.. GENERATED FROM PYTHON SOURCE LINES 117-119

Train on CPU or GPU if available
++++++++++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 119-123

.. code-block:: default

    device = "cuda" if get_device().upper() == 'GPU' else 'cpu'
    print(f"device={device!r} get_device()={get_device()!r}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    device='cpu' get_device()='CPU'

.. GENERATED FROM PYTHON SOURCE LINES 124-134

Stochastic Gradient Descent
+++++++++++++++++++++++++++

The training logic is hidden in class
:class:`OrtGradientOptimizer <onnxcustom.training.optimizers.OrtGradientOptimizer>`.
It follows the :epkg:`scikit-learn` API
(see :class:`sklearn.linear_model.SGDRegressor`).
The gradient graph is not available at this stage.

.. GENERATED FROM PYTHON SOURCE LINES 134-142

.. code-block:: default

    train_session = OrtGradientOptimizer(
        onx_train, list(weights), device=device, verbose=1,
        learning_rate=1e-2, warm_start=False, max_iter=200, batch_size=10,
        saved_gradient="saved_gradient.onnx")

    train_session.fit(X, y)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0%| | 0/200 [00:00

.. GENERATED FROM PYTHON SOURCE LINES 155-156

The training graph looks like the following.

.. GENERATED FROM PYTHON SOURCE LINES 156-167

.. code-block:: default

    with open("saved_gradient.onnx.training.onnx", "rb") as f:
        graph = onnx.load(f)
        # give every unnamed output a unique name before plotting
        for inode, node in enumerate(graph.graph.node):
            if '' in node.output:
                for i in range(len(node.output)):
                    if node.output[i] == "":
                        node.output[i] = "n%d-%d" % (inode, i)
        plot_onnxs(graph, title='Training graph')

.. image-sg:: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_003.png
   :alt: Training graph
   :srcset: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 168-175

The convergence speed is not the same, but the two gradient descents
do not update the gradient multiplier the same way.
:epkg:`onnxruntime-training` does not implement any gradient descent;
it only computes the gradient.
Implementing the descent itself is the purpose of
:class:`OrtGradientOptimizer <onnxcustom.training.optimizers.OrtGradientOptimizer>`.
The next example digs into the implementation details.

.. GENERATED FROM PYTHON SOURCE LINES 175-178

.. code-block:: default

    # import matplotlib.pyplot as plt
    # plt.show()

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 4.645 seconds)

.. _sphx_glr_download_gyexamples_plot_orttraining_linear_regression.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_orttraining_linear_regression.py <plot_orttraining_linear_regression.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_orttraining_linear_regression.ipynb <plot_orttraining_linear_regression.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_