.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gyexamples/plot_orttraining_nn_gpu_fwbw.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_gyexamples_plot_orttraining_nn_gpu_fwbw.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gyexamples_plot_orttraining_nn_gpu_fwbw.py:

.. _l-orttraining-nn-gpu-fwbw:

Forward backward on a neural network on GPU
===========================================

This example builds on :ref:`l-orttraining-linreg-gpu` to train
a neural network from :epkg:`scikit-learn` on GPU. It reuses the code
introduced in :ref:`l-orttraining-linreg-fwbw`.

.. contents::
    :local:

A neural network with scikit-learn
++++++++++++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 19-51

.. code-block:: default

    import warnings
    import numpy
    from pandas import DataFrame
    from onnxruntime import get_device
    from sklearn.datasets import make_regression
    from sklearn.model_selection import train_test_split
    from sklearn.neural_network import MLPRegressor
    from sklearn.metrics import mean_squared_error
    from onnxcustom.plotting.plotting_onnx import plot_onnxs
    from mlprodict.onnx_conv import to_onnx
    from onnxcustom.utils.orttraining_helper import get_train_initializer
    from onnxcustom.utils.onnx_helper import onnx_rename_weights
    from onnxcustom.training.optimizers_partial import (
        OrtGradientForwardBackwardOptimizer)


    X, y = make_regression(1000, n_features=10, bias=2)
    X = X.astype(numpy.float32)
    y = y.astype(numpy.float32)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=100,
                      solver='sgd', learning_rate_init=5e-5,
                      n_iter_no_change=1000, batch_size=10, alpha=0,
                      momentum=0, nesterovs_momentum=False)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        nn.fit(X_train, y_train)

    print(nn.loss_curve_)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [16672.610546875, 16560.276959635416, 15882.732776692708, 10563.512243652343, 5050.831483256022, 574.1751454671224, 313.5375997924805, 232.1086828104655, 182.2064368693034, 149.1921142578125, 117.04410995483398, 87.2079758199056, 64.2483705774943, 45.84810812632243, 32.44823532104492, 23.301406904856364, 17.20445534388224, 13.230918553670248, 10.854460395177206, 9.229736528396607, 8.117675099372864, 7.205892693201701, 6.558077507019043, 6.035153317848842, 5.553128098249435, 5.2076671846707665, 4.851039233207703, 4.597374379634857, 4.337599027951558, 4.093200965722402, 3.9451861921946207, 3.7354650489489236, 3.570521529515584, 3.4363346870740257, 3.2899771054585774, 3.157844088872274, 3.0404599777857464, 2.916410984992981, 2.8262819465001425, 2.720105598370234, 2.6255326795578005, 2.5208125074704486, 2.434456414381663, 2.3487924567858376, 2.274977253675461, 2.212032690842946, 2.1391211891174318, 2.0823371533552804, 2.0334598263104757, 1.957404642502467, 1.9217858338356018, 1.8669374255339304, 1.812948048512141, 1.779149255355199, 1.7340156320730846, 1.6899153021971385, 1.6511367722352346, 1.6287659740447997, 1.5881430466969808, 1.5528664668401082, 1.5274977441628774, 1.4923988771438599, 1.472212245464325, 1.4330989170074462, 1.4177746669451396, 1.3770861375331878, 1.3351058677832286, 1.3219218897819518, 1.2960139457384745, 1.275530304312706, 1.2432468036810558, 1.2329948961734771, 1.2054487474759419, 1.191846824089686, 1.1775798761844636, 1.1554956610997518, 1.1317561089992523, 1.1128112709522247, 1.0979741044839224, 1.0864841641982397, 1.0655990505218507, 1.05265309492747, 1.0373732882738114, 1.0158864573637645, 1.0067336962620417, 0.999214769800504, 0.9788467444976171, 0.9639906392494837, 0.9506902368863424, 0.9434160772959391, 0.9244991546869278, 0.9220002925395966, 0.9021422656377157, 0.8950206072131792, 0.8875000888109207, 0.876043464342753, 0.8623660150170326, 0.8455535284678142, 0.8427771917978922, 0.8288381878534953]

.. GENERATED FROM PYTHON SOURCE LINES 52-53

Score:

.. GENERATED FROM PYTHON SOURCE LINES 53-57

.. code-block:: default

    print(f"mean_squared_error={mean_squared_error(y_test, nn.predict(X_test))!r}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    mean_squared_error=2.1054227

.. GENERATED FROM PYTHON SOURCE LINES 58-60

Conversion to ONNX
++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 60-64

.. code-block:: default

    onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
    plot_onnxs(onx)

.. image-sg:: /gyexamples/images/sphx_glr_plot_orttraining_nn_gpu_fwbw_001.png
   :alt: plot orttraining nn gpu fwbw
   :srcset: /gyexamples/images/sphx_glr_plot_orttraining_nn_gpu_fwbw_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 65-66

Initializers to train

.. GENERATED FROM PYTHON SOURCE LINES 66-70

.. code-block:: default

    weights = list(sorted(get_train_initializer(onx)))
    print(weights)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ['coefficient', 'coefficient1', 'coefficient2', 'intercepts', 'intercepts1', 'intercepts2']

.. GENERATED FROM PYTHON SOURCE LINES 71-74

Training graph with forward backward
++++++++++++++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 74-78

.. code-block:: default

    device = "cuda" if get_device().upper() == 'GPU' else 'cpu'

    print(f"device={device!r} get_device()={get_device()!r}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    device='cpu' get_device()='CPU'

.. GENERATED FROM PYTHON SOURCE LINES 79-86

The training session. The first call fails because class
:epkg:`TrainingAgent` expects the list of weights to train
to be in alphabetical order. The list `onx.graph.initializer`
must therefore be sorted by the names of its initializers.
Otherwise the process could crash, unless the problem is
caught earlier, as it is here with the following exception.

.. GENERATED FROM PYTHON SOURCE LINES 86-95

.. code-block:: default

    try:
        train_session = OrtGradientForwardBackwardOptimizer(
            onx, device=device, verbose=1,
            warm_start=False, max_iter=100, batch_size=10)
        train_session.fit(X, y)
    except ValueError as e:
        print(e)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    List of weights to train must be sorted but ['coefficient', 'intercepts', 'coefficient1', 'intercepts1', 'coefficient2', 'intercepts2'] is not. You shoud use function onnx_rename_weights to do that before calling this class.

.. GENERATED FROM PYTHON SOURCE LINES 96-100

Function :func:`onnx_rename_weights
<onnxcustom.utils.onnx_helper.onnx_rename_weights>`
does not change the order of the initializers but renames them
so that their names are sorted. Class :epkg:`TrainingAgent`
can then work.

.. GENERATED FROM PYTHON SOURCE LINES 100-107

.. code-block:: default

    onx = onnx_rename_weights(onx)
    train_session = OrtGradientForwardBackwardOptimizer(
        onx, device=device, verbose=1,
        learning_rate=5e-5, warm_start=False,
        max_iter=100, batch_size=10)
    train_session.fit(X, y)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0%| | 0/100 [00:00

.. GENERATED FROM PYTHON SOURCE LINES 122-124

The convergence rates differ, but the two classes
do not update the learning rate the same way.

.. GENERATED FROM PYTHON SOURCE LINES 124-127

.. code-block:: default

    # import matplotlib.pyplot as plt
    # plt.show()

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 26.585 seconds)

.. _sphx_glr_download_gyexamples_plot_orttraining_nn_gpu_fwbw.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_orttraining_nn_gpu_fwbw.py <plot_orttraining_nn_gpu_fwbw.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_orttraining_nn_gpu_fwbw.ipynb <plot_orttraining_nn_gpu_fwbw.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_