{"cells": [{"cell_type": "markdown", "id": "61c4a48d", "metadata": {}, "source": ["# Stochastic Gradient Descent on a simple function\n", "\n", "[onnxruntime-training](https://github.com/microsoft/onnxruntime) is an extension of onnxruntime or, more precisely, the same library compiled with different settings. It provides a way to compute the gradient of a function defined by an ONNX graph."]}, {"cell_type": "code", "execution_count": 1, "id": "cb34ae6d", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 2, "metadata": {}, "output_type": "execute_result"}], "source": ["from jyquickhelper import add_notebook_menu\n", "add_notebook_menu()"]}, {"cell_type": "code", "execution_count": 2, "id": "ef1a1dfa", "metadata": {}, "outputs": [], "source": ["%load_ext mlprodict"]}, {"cell_type": "markdown", "id": "e39f1613", "metadata": {}, "source": ["## A simple problem\n", "\n", "Let's choose a simple regression problem defined by $z = -1 - 2x + 3y + \\frac{1}{2}x^2 -\\frac{1}{3} y^2 +\\epsilon$, which we try to approximate with a function $f(x,y) = a + bx + cy + dx^2 + ey^2$. The coefficients are determined by solving an optimization problem with stochastic gradient descent."]}, {"cell_type": "code", "execution_count": 3, "id": "f50d8a3d", "metadata": {"scrolled": false}, "outputs": [{"data": {"text/plain": ["array([[-1. ],\n", " [-2.5 ],\n", " [ 1.6666701 ],\n", " [ 0.16667008],\n", " [ 1.6666799 ]], dtype=float32)"]}, "execution_count": 4, "metadata": {}, "output_type": "execute_result"}], "source": ["from typing import Any\n", "import numpy\n", "import mlprodict.npy.numpy_onnx_impl as npnx\n", "from mlprodict.npy import onnxnumpy_default, NDArray\n", "\n", "\n", "@onnxnumpy_default\n", "def fct(x: NDArray[(None, 2), numpy.float32]) -> NDArray[(None, 1), numpy.float32]:\n", " coef_x = numpy.array([[-2, 3]], dtype=numpy.float32) \n", " coef_x2 = numpy.array([[0.5, -0.33333]], dtype=numpy.float32)\n", " bias = numpy.array([-1], dtype=numpy.float32)\n", " poly = x * coef_x + x * x * coef_x2\n", " y = poly[:, 0] + poly[:, 1] + bias\n", " return y.reshape((-1, 1))\n", "\n", "\n", "x = numpy.array([[0, 0], [1, 0], [0, 1], [1, 1], [2, 2]], dtype=numpy.float32)\n", "fct(x)"]}, {"cell_type": "code", "execution_count": 4, "id": "0e3735c0", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 5, "metadata": {}, "output_type": "execute_result"}], "source": ["%onnxview fct.to_onnx()"]}, {"cell_type": "code", "execution_count": 5, "id": "f5c35b29", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "init: name='init' type=dtype('float32') shape=(0,) -- array([ 0.5 , -0.33333], dtype=float32)\n", "init: name='init_1' type=dtype('float32') shape=(0,) -- array([-2., 3.], dtype=float32)\n", "init: name='init_2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='init_3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "init: name='init_5' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='init_b10' type=dtype('float32') shape=(0,) -- array([-1.], dtype=float32)\n", "init: name='init_b11' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "Mul(x, x) -> out_mul_0\n", " Mul(out_mul_0, init) -> out_mul_0_1\n", "Mul(x, init_1) -> out_mul_0_2\n", " Add(out_mul_0_2, out_mul_0_1) -> out_add_0\n", " Slice(out_add_0, init_2, init_3, init_2) -> out_sli_0\n", " Squeeze(out_sli_0, init_2) -> out_squ_0\n", " Slice(out_add_0, init_5, init_2, init_2) -> out_sli_0_1\n", " Squeeze(out_sli_0_1, init_2) -> out_squ_0_1\n", " Add(out_squ_0_1, out_squ_0) -> out_add_0_1\n", " Add(out_add_0_1, init_b10) -> out_add_0_2\n", " Reshape(out_add_0_2, init_b11) -> y\n", "output: name='y' type=dtype('float32') shape=(0, 1)\n"]}], "source": ["from mlprodict.plotting.text_plot import onnx_simple_text_plot\n", "print(onnx_simple_text_plot(fct.to_onnx()))"]}, {"cell_type": "markdown", "id": "b56814b7", "metadata": {}, "source": ["## Gradient: backpropagation\n", "\n", "Let's look into the gradient."]}, {"cell_type": "code", "execution_count": 6, "id": "695682f5", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["from onnxcustom.training.grad_helper import onnx_derivative, DerivativeOptions\n", "\n", "onx = fct.to_onnx()\n", "grad = onnx_derivative(onx)\n", "%onnxview grad"]}, {"cell_type": "code", "execution_count": 7, "id": "72d51122", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "opset: domain='com.microsoft.nchwc' version=1\n", "opset: domain='ai.onnx.ml' version=2\n", "opset: domain='com.ms.internal.nhwc' version=1\n", "opset: domain='ai.onnx.training' version=1\n", "opset: domain='ai.onnx.preview.training' version=1\n", "opset: domain='com.microsoft' version=1\n", "opset: domain='com.microsoft.experimental' version=1\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "input: name='init' type=dtype('float32') shape=(1, 2)\n", "input: name='init_1' type=dtype('float32') shape=(1, 2)\n", "input: name='init_b10' type=dtype('float32') shape=(1,)\n", "input: name='y_grad' type=dtype('float32') shape=(0, 1)\n", "init: name='init_5' type=dtype('int64') shape=(0,)\n", "init: name='init_2' type=dtype('int64') shape=(0,)\n", "init: name='init_3' type=dtype('int64') shape=(0,)\n", "output: name='x_grad' type=dtype('float32') shape=(0, 2)\n", "output: name='init_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_1_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_b10_grad' type=dtype('float32') shape=(1,)\n"]}], "source": ["from mlprodict.plotting.text_plot import onnx_text_plot_io, onnx_simple_text_plot\n", "print(onnx_text_plot_io(grad))"]}, {"cell_type": "code", "execution_count": 8, "id": "753199ee", "metadata": {}, "outputs": [], "source": ["from mlprodict.onnx_tools.onnx_manipulations import onnx_rename_names\n", "renamed = onnx_rename_names(grad)"]}, {"cell_type": "code", "execution_count": 9, "id": "2d0b3a76", "metadata": {}, "outputs": [{"name": 
"stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "opset: domain='com.microsoft.nchwc' version=1\n", "opset: domain='ai.onnx.ml' version=2\n", "opset: domain='com.ms.internal.nhwc' version=1\n", "opset: domain='ai.onnx.training' version=1\n", "opset: domain='ai.onnx.preview.training' version=1\n", "opset: domain='com.microsoft' version=1\n", "opset: domain='com.microsoft.experimental' version=1\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "input: name='init' type=dtype('float32') shape=(1, 2)\n", "input: name='init_1' type=dtype('float32') shape=(1, 2)\n", "input: name='init_b10' type=dtype('float32') shape=(1,)\n", "input: name='y_grad' type=dtype('float32') shape=(0, 1)\n", "init: name='i0' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='i1' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='i2' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "Mul(x, x) -> r0\n", " Mul(r0, init) -> r1\n", " Shape(r1) -> r32\n", "Mul(x, init_1) -> r2\n", " Add(r2, r1) -> r3\n", " Slice(r3, i1, i2, i1) -> r4\n", " Squeeze(r4, i1) -> r5\n", " Shape(r5) -> r18\n", " Slice(r3, i0, i1, i1) -> r6\n", " Squeeze(r6, i1) -> r7\n", " Add(r7, r5) -> r8\n", " Add(r8, init_b10) -> r9\n", " Shape(r9) -> r10\n", " Reshape(y_grad, r10, allowzero=0) -> r11\n", "Shape(init_b10) -> r12\n", "Shape(r8) -> r13\n", " BroadcastGradientArgs(r13, r12) -> r14, r15\n", " ReduceSum(r11, r14, keepdims=1, noop_with_empty_axes=1) -> r16\n", " Reshape(r16, r13, allowzero=0) -> r17\n", "Shape(r7) -> r19\n", " BroadcastGradientArgs(r19, r18) -> r20, r21\n", " ReduceSum(r17, r21, keepdims=1, noop_with_empty_axes=1) -> r22\n", " Reshape(r22, r18, allowzero=0) -> r23\n", " Unsqueeze(r23, i1) -> r24\n", " Shape(r3) -> r25\n", " SliceGrad(r24, r25, i1, i2, i1) -> r26\n", " ReduceSum(r17, r20, keepdims=1, noop_with_empty_axes=1) -> r27\n", " Reshape(r27, r19, allowzero=0) -> r28\n", " Unsqueeze(r28, i1) -> r29\n", " 
SliceGrad(r29, r25, i0, i1, i1) -> r30\n", " Sum(r30, r26) -> r31\n", " Shape(r2) -> r33\n", " BroadcastGradientArgs(r33, r32) -> r34, r35\n", " ReduceSum(r31, r35, keepdims=1, noop_with_empty_axes=1) -> r36\n", " Reshape(r36, r32, allowzero=0) -> r37\n", " Mul(r37, init) -> r38\n", "Shape(init) -> r39\n", "Shape(r0) -> r40\n", " BroadcastGradientArgs(r40, r39) -> r41, r42\n", " ReduceSum(r38, r41, keepdims=1, noop_with_empty_axes=1) -> r43\n", " Reshape(r43, r40, allowzero=0) -> r44\n", " Mul(r44, x) -> r45\n", "ReduceSum(r31, r34, keepdims=1, noop_with_empty_axes=1) -> r46\n", " Reshape(r46, r33, allowzero=0) -> r47\n", " Mul(r47, init_1) -> r48\n", "Shape(init_1) -> r49\n", "Shape(x) -> r50\n", " BroadcastGradientArgs(r50, r49) -> r51, r52\n", " ReduceSum(r48, r51, keepdims=1, noop_with_empty_axes=1) -> r53\n", " Reshape(r53, r50, allowzero=0) -> r54\n", " Sum(r54, r45, r45) -> x_grad\n", "ReduceSum(r11, r15, keepdims=1, noop_with_empty_axes=1) -> r55\n", " Reshape(r55, r12, allowzero=0) -> init_b10_grad\n", "Mul(r37, r0) -> r56\n", " ReduceSum(r56, r42, keepdims=1, noop_with_empty_axes=1) -> r57\n", " Reshape(r57, r39, allowzero=0) -> init_grad\n", "Mul(r47, x) -> r58\n", " ReduceSum(r58, r52, keepdims=1, noop_with_empty_axes=1) -> r59\n", " Reshape(r59, r49, allowzero=0) -> init_1_grad\n", "output: name='x_grad' type=dtype('float32') shape=(0, 2)\n", "output: name='init_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_1_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_b10_grad' type=dtype('float32') shape=(1,)\n"]}], "source": ["print(onnx_simple_text_plot(renamed))"]}, {"cell_type": "code", "execution_count": 10, "id": "d0f09fbc", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'Add',\n", " 'BroadcastGradientArgs',\n", " 'Mul',\n", " 'ReduceSum',\n", " 'Reshape',\n", " 'Shape',\n", " 'Slice',\n", " 'SliceGrad',\n", " 'Squeeze',\n", " 'Sum',\n", " 'Unsqueeze'}"]}, "execution_count": 11, "metadata": {}, "output_type": 
"execute_result"}], "source": ["set(n.op_type for n in grad.graph.node)"]}, {"cell_type": "markdown", "id": "6e3558bf", "metadata": {}, "source": ["The resulting graph assumes the gradient for `y_grad` is known. That's the case for a layer in a neural network. In our case, this gradient should come from the loss. Let's add it to the graph."]}, {"cell_type": "markdown", "id": "701eead0", "metadata": {}, "source": ["## Add a square loss"]}, {"cell_type": "code", "execution_count": 11, "id": "4d885b0d", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 12, "metadata": {}, "output_type": "execute_result"}], "source": ["from onnxcustom.utils.orttraining_helper import add_loss_output\n", "onx_loss = add_loss_output(onx)\n", "\n", "%onnxview onx_loss"]}, {"cell_type": "code", "execution_count": 12, "id": "6b451748", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "input: name='label' type=dtype('float32') shape=(0, 1)\n", "init: name='init' type=dtype('float32') shape=(0,) -- array([ 0.5 , -0.33333], dtype=float32)\n", "init: name='init_1' type=dtype('float32') shape=(0,) -- array([-2., 3.], dtype=float32)\n", "init: name='init_2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='init_3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "init: name='init_5' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='init_b10' type=dtype('float32') shape=(0,) -- array([-1.], dtype=float32)\n", "init: name='init_b11' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "Mul(x, x) -> out_mul_0\n", " Mul(out_mul_0, init) -> out_mul_0_1\n", "Mul(x, init_1) -> out_mul_0_2\n", " Add(out_mul_0_2, out_mul_0_1) -> out_add_0\n", " Slice(out_add_0, init_2, init_3, init_2) -> out_sli_0\n", " Squeeze(out_sli_0, init_2) -> out_squ_0\n", " Slice(out_add_0, init_5, init_2, init_2) -> out_sli_0_1\n", " Squeeze(out_sli_0_1, init_2) -> out_squ_0_1\n", " Add(out_squ_0_1, out_squ_0) -> out_add_0_1\n", " Add(out_add_0_1, init_b10) -> out_add_0_2\n", " Reshape(out_add_0_2, init_b11) -> y\n", " Sub(y, label) -> loss_diff\n", " Mul(loss_diff, loss_diff) -> loss_diff_2\n", " ReduceSum(loss_diff_2) -> loss\n", "output: name='loss' type=dtype('float32') shape=(1, 1)\n", "output: name='y' type=dtype('float32') shape=(0, 1)\n"]}], "source": ["print(onnx_simple_text_plot(onx_loss))"]}, {"cell_type": "markdown", 
"id": "db99d22e", "metadata": {}, "source": ["The graph has 5 inputs: `x`, `label` or the expected target, and the weights and two outputs, the function output and the loss. We don't need the first one so we remove it."]}, {"cell_type": "code", "execution_count": 13, "id": "15ca92ad", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "input: name='label' type=dtype('float32') shape=(0, 1)\n", "init: name='init' type=dtype('float32') shape=(0,) -- array([ 0.5 , -0.33333], dtype=float32)\n", "init: name='init_1' type=dtype('float32') shape=(0,) -- array([-2., 3.], dtype=float32)\n", "init: name='init_2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='init_3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "init: name='init_5' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='init_b10' type=dtype('float32') shape=(0,) -- array([-1.], dtype=float32)\n", "init: name='init_b11' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "Mul(x, x) -> out_mul_0\n", " Mul(out_mul_0, init) -> out_mul_0_1\n", "Mul(x, init_1) -> out_mul_0_2\n", " Add(out_mul_0_2, out_mul_0_1) -> out_add_0\n", " Slice(out_add_0, init_5, init_2, init_2) -> out_sli_0_1\n", " Squeeze(out_sli_0_1, init_2) -> out_squ_0_1\n", " Slice(out_add_0, init_2, init_3, init_2) -> out_sli_0\n", " Squeeze(out_sli_0, init_2) -> out_squ_0\n", " Add(out_squ_0_1, out_squ_0) -> out_add_0_1\n", " Add(out_add_0_1, init_b10) -> out_add_0_2\n", " Reshape(out_add_0_2, init_b11) -> y\n", " Sub(y, label) -> loss_diff\n", " Mul(loss_diff, loss_diff) -> loss_diff_2\n", " ReduceSum(loss_diff_2) -> loss\n", "output: name='loss' type=dtype('float32') shape=(1, 1)\n"]}], "source": ["from mlprodict.onnx_tools.onnx_manipulations import select_model_inputs_outputs\n", "\n", "onx_loss_only = select_model_inputs_outputs(onx_loss, 
outputs=['loss'])\n", "print(onnx_simple_text_plot(onx_loss_only))"]}, {"cell_type": "markdown", "id": "3d3a5477", "metadata": {}, "source": ["## Gradient again: loss + backpropagation"]}, {"cell_type": "code", "execution_count": 14, "id": "0496baf6", "metadata": {"scrolled": false}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 15, "metadata": {}, "output_type": "execute_result"}], "source": ["grad_loss = onnx_rename_names(onnx_derivative(\n", " onx_loss_only, options=DerivativeOptions.FillGrad | DerivativeOptions.KeepOutputs))\n", "%onnxview grad_loss"]}, {"cell_type": "code", "execution_count": 15, "id": "e497697d", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "opset: domain='com.microsoft.nchwc' version=1\n", "opset: domain='ai.onnx.ml' version=2\n", "opset: domain='com.ms.internal.nhwc' version=1\n", "opset: domain='ai.onnx.training' version=1\n", "opset: domain='ai.onnx.preview.training' version=1\n", "opset: domain='com.microsoft' version=1\n", "opset: domain='com.microsoft.experimental' version=1\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "input: name='label' type=dtype('float32') shape=(0, 1)\n", "input: name='init' type=dtype('float32') shape=(1, 2)\n", "input: name='init_1' type=dtype('float32') shape=(1, 2)\n", "input: name='init_b10' type=dtype('float32') shape=(1,)\n", "init: name='i0' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "init: name='i1' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='i2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='i3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "Mul(x, init_1) -> r0\n", " Shape(r0) -> r47\n", "Mul(x, x) -> r1\n", " Mul(r1, init) -> r2\n", " Add(r0, r2) -> r3\n", " Slice(r3, i1, i2, i2) -> r4\n", " Squeeze(r4, i2) -> r5\n", " Shape(r5) -> r33\n", " Slice(r3, i2, i3, i2) -> r6\n", " Squeeze(r6, i2) -> r7\n", " Add(r5, r7) -> r8\n", " Add(r8, init_b10) -> r9\n", " Reshape(r9, i0, allowzero=0) -> r10\n", " Sub(r10, label) -> r11\n", " Mul(r11, r11) -> r12\n", " ReduceSum(r12, keepdims=1, noop_with_empty_axes=0) -> loss\n", " Shape(loss) -> r76\n", " ConstantOfShape(r76) -> r14\n", " Shape(r12) -> r13\n", " 
Expand(r14, r13) -> r15\n", " Mul(r15, r11) -> r16\n", " Sum(r16, r16) -> r17\n", "Shape(label) -> r18\n", "Shape(r10) -> r19\n", " BroadcastGradientArgs(r19, r18) -> r20, r21\n", " ReduceSum(r17, r20, keepdims=1, noop_with_empty_axes=1) -> r22\n", " Reshape(r22, r19, allowzero=0) -> r23\n", "Shape(r9) -> r24\n", " Reshape(r23, r24, allowzero=0) -> r25\n", "Shape(init_b10) -> r26\n", "Shape(r8) -> r27\n", " BroadcastGradientArgs(r27, r26) -> r28, r29\n", " ReduceSum(r25, r28, keepdims=1, noop_with_empty_axes=1) -> r30\n", " Reshape(r30, r27, allowzero=0) -> r31\n", "Shape(r7) -> r32\n", " BroadcastGradientArgs(r33, r32) -> r34, r35\n", " ReduceSum(r31, r34, keepdims=1, noop_with_empty_axes=1) -> r36\n", " Reshape(r36, r33, allowzero=0) -> r37\n", " Unsqueeze(r37, i2) -> r38\n", " Shape(r3) -> r39\n", " SliceGrad(r38, r39, i1, i2, i2) -> r40\n", " ReduceSum(r31, r35, keepdims=1, noop_with_empty_axes=1) -> r41\n", " Reshape(r41, r32, allowzero=0) -> r42\n", " Unsqueeze(r42, i2) -> r43\n", " SliceGrad(r43, r39, i2, i3, i2) -> r44\n", " Sum(r44, r40) -> r45\n", " Shape(r2) -> r46\n", " BroadcastGradientArgs(r47, r46) -> r48, r49\n", " ReduceSum(r45, r48, keepdims=1, noop_with_empty_axes=1) -> r50\n", " Reshape(r50, r47, allowzero=0) -> r51\n", " Mul(r51, init_1) -> r52\n", "Shape(init_1) -> r53\n", "Shape(x) -> r54\n", " BroadcastGradientArgs(r54, r53) -> r55, r56\n", " ReduceSum(r52, r55, keepdims=1, noop_with_empty_axes=1) -> r57\n", " Reshape(r57, r54, allowzero=0) -> r58\n", "ReduceSum(r45, r49, keepdims=1, noop_with_empty_axes=1) -> r59\n", " Reshape(r59, r46, allowzero=0) -> r60\n", " Mul(r60, init) -> r61\n", "Shape(init) -> r62\n", "Shape(r1) -> r63\n", " BroadcastGradientArgs(r63, r62) -> r64, r65\n", " ReduceSum(r61, r64, keepdims=1, noop_with_empty_axes=1) -> r66\n", " Reshape(r66, r63, allowzero=0) -> r67\n", " Mul(r67, x) -> r68\n", " Sum(r68, r68, r58) -> x_grad\n", "ReduceSum(r17, r21, keepdims=1, noop_with_empty_axes=1) -> r69\n", " Reshape(r69, r18, 
allowzero=0) -> r70\n", " Neg(r70) -> label_grad\n", "ReduceSum(r25, r29, keepdims=1, noop_with_empty_axes=1) -> r71\n", " Reshape(r71, r26, allowzero=0) -> init_b10_grad\n", "Mul(r51, x) -> r72\n", " ReduceSum(r72, r56, keepdims=1, noop_with_empty_axes=1) -> r73\n", " Reshape(r73, r53, allowzero=0) -> init_1_grad\n", "Mul(r60, r1) -> r74\n", " ReduceSum(r74, r65, keepdims=1, noop_with_empty_axes=1) -> r75\n", " Reshape(r75, r62, allowzero=0) -> init_grad\n", "output: name='x_grad' type=dtype('float32') shape=(0, 2)\n", "output: name='label_grad' type=dtype('float32') shape=(0, 1)\n", "output: name='init_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_1_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_b10_grad' type=dtype('float32') shape=(1,)\n", "output: name='loss' type=dtype('float32') shape=(1, 1)\n"]}], "source": ["print(onnx_simple_text_plot(grad_loss))"]}, {"cell_type": "markdown", "id": "ba52a6bd", "metadata": {}, "source": ["Let's compute the gradient."]}, {"cell_type": "code", "execution_count": 16, "id": "8fdbf27e", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[0., 0.],\n", " [1., 0.],\n", " [0., 1.],\n", " [1., 1.],\n", " [2., 2.]], dtype=float32)"]}, "execution_count": 17, "metadata": {}, "output_type": "execute_result"}], "source": ["x"]}, {"cell_type": "code", "execution_count": 17, "id": "e7104f1c", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[-1. 
],\n", " [-2.5 ],\n", " [ 1.6666701 ],\n", " [ 0.16667008],\n", " [ 1.6666799 ]], dtype=float32)"]}, "execution_count": 18, "metadata": {}, "output_type": "execute_result"}], "source": ["y = fct(x)\n", "y"]}, {"cell_type": "code", "execution_count": 18, "id": "3f7240ef", "metadata": {"scrolled": false}, "outputs": [], "source": ["from mlprodict.onnxrt import OnnxInference\n", "\n", "oinf = OnnxInference(grad_loss, runtime='onnxruntime1')"]}, {"cell_type": "code", "execution_count": 19, "id": "550a9073", "metadata": {"scrolled": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["{'init_1_grad': array([[109.333244, 102.666565]], dtype=float32),\n", " 'init_b10_grad': array([76.6666], dtype=float32),\n", " 'init_grad': array([[193.33316, 186.66649]], dtype=float32),\n", " 'label_grad': array([[ -4. ],\n", " [-12. ],\n", " [ -5.33332 ],\n", " [-13.333321],\n", " [-41.99996 ]], dtype=float32),\n", " 'loss': array([[532.5546]], dtype=float32),\n", " 'x_grad': array([[ 2. , 1.33332 ],\n", " [ 54. , 3.99996 ],\n", " [ 2.66666 , 33.777676],\n", " [ 59.999943, 84.44432 ],\n", " [356.99966 , 517.9994 ]], dtype=float32)}\n"]}], "source": ["import pprint\n", "\n", "init = numpy.array([[2, 3]], dtype=numpy.float32)\n", "init_1 = numpy.array([[0.5, 0.33333]], dtype=numpy.float32)\n", "init_b10 = numpy.array([1], dtype=numpy.float32)\n", "result = oinf.run({'x': x, 'label': y, \n", " 'init': init, 'init_1': init_1, 'init_b10': init_b10})\n", "pprint.pprint(result)"]}, {"cell_type": "markdown", "id": "211dec4a", "metadata": {}, "source": ["We could use this gradient to implement a stochastic gradient descent in Python. Two comments:\n", "* If we implement it with numpy, it cannot run on GPU.\n", "* If we use OrtValue (the tensor type from onnxruntime), how do we perform a simple addition between two OrtValue?\n", "\n", "We need to implement the second option. 
A simple addition between two OrtValue must be done with an ONNX graph."]}, {"cell_type": "markdown", "id": "961435d5", "metadata": {}, "source": ["## TrainingSession"]}, {"cell_type": "code", "execution_count": 20, "id": "18fd1c4d", "metadata": {}, "outputs": [{"data": {"text/plain": ["((100, 2), (100, 1))"]}, "execution_count": 21, "metadata": {}, "output_type": "execute_result"}], "source": ["X = numpy.random.randn(100, 2).astype(numpy.float32) / 10\n", "y = fct(X) + (numpy.random.randn(100, 1) / 1000).astype(numpy.float32)\n", "X.shape, y.shape"]}, {"cell_type": "code", "execution_count": 21, "id": "60e1c260", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "init: name='init' type=dtype('float32') shape=(0,) -- array([ 0.5 , -0.33333], dtype=float32)\n", "init: name='init_1' type=dtype('float32') shape=(0,) -- array([-2., 3.], dtype=float32)\n", "init: name='init_2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='init_3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "init: name='init_5' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='init_b10' type=dtype('float32') shape=(0,) -- array([-1.], dtype=float32)\n", "init: name='init_b11' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "Mul(x, x) -> out_mul_0\n", " Mul(out_mul_0, init) -> out_mul_0_1\n", "Mul(x, init_1) -> out_mul_0_2\n", " Add(out_mul_0_2, out_mul_0_1) -> out_add_0\n", " Slice(out_add_0, init_2, init_3, init_2) -> out_sli_0\n", " Squeeze(out_sli_0, init_2) -> out_squ_0\n", " Slice(out_add_0, init_5, init_2, init_2) -> out_sli_0_1\n", " Squeeze(out_sli_0_1, init_2) -> out_squ_0_1\n", " Add(out_squ_0_1, out_squ_0) -> out_add_0_1\n", " Add(out_add_0_1, init_b10) -> out_add_0_2\n", " Reshape(out_add_0_2, init_b11) -> y\n", "output: name='y' type=dtype('float32') shape=(0, 1)\n"]}], "source": 
["print(onnx_simple_text_plot(onx))"]}, {"cell_type": "code", "execution_count": 22, "id": "f6bfc26e", "metadata": {}, "outputs": [{"data": {"text/plain": ["OrtGradientOptimizer(model_onnx='ir_version...', weights_to_train=['init', 'init_1', 'init_b10'], loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=5, learning_rate=LearningRateSGD(eta0=0.1, alpha=0.0001, power_t=0.25, learning_rate='invscaling'), value=0.03162277660168379, device='cpu', warm_start=False, verbose=0, validation_every=10, saved_gradient=None, sample_weight_name='weight')"]}, "execution_count": 23, "metadata": {}, "output_type": "execute_result"}], "source": ["from onnxcustom.training.optimizers import OrtGradientOptimizer\n", "\n", "train_session = OrtGradientOptimizer(\n", " onx_loss, ['init', 'init_1', 'init_b10'], learning_rate=1e-1,\n", " batch_size=5, max_iter=100)\n", "\n", "train_session.fit(X, y)"]}, {"cell_type": "code", "execution_count": 23, "id": "5880b69f", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'init': array([[-0.34785354, 1.1399053 ]], dtype=float32),\n", " 'init_1': array([[-1.9156165, 2.4292002]], dtype=float32),\n", " 'init_b10': array([-1.0016667], dtype=float32)}"]}, "execution_count": 24, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.trained_coef_"]}, {"cell_type": "code", "execution_count": 24, "id": "f19f97fe", "metadata": {}, "outputs": [{"data": {"text/plain": ["[0.0036812867, 0.0038135047, 0.0037041684, 0.0037206002, 0.0032002896]"]}, "execution_count": 25, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.train_losses_[-5:]"]}, {"cell_type": "code", "execution_count": 25, "id": "82cbd6cf", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["import pandas\n", "\n", "pandas.DataFrame({'loss': train_session.train_losses_}).plot();"]}, {"cell_type": "markdown", "id": "4d3ab358", "metadata": {}, "source": ["## Fordward backward: TrainingAgent\n", "\n", "This second implementation uses [TrainingAgent](http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/api/onnxruntime_python/training_partial.html#trainingagent)."]}, {"cell_type": "code", "execution_count": 26, "id": "809cb2f3", "metadata": {}, "outputs": [], "source": ["from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer\n", "\n", "train_session = OrtGradientForwardBackwardOptimizer(\n", " onx, ['init', 'init_1', 'init_b10'], learning_rate=1e-1, \n", " batch_size=2, max_iter=100)"]}, {"cell_type": "code", "execution_count": 27, "id": "e47bfdf6", "metadata": {}, "outputs": [{"data": {"text/plain": ["OrtGradientForwardBackwardOptimizer(model_onnx='ir_version...', weights_to_train=['init', 'init_1', 'init_b10'], loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=2, learning_rate=LearningRateSGD(eta0=0.1, alpha=0.0001, power_t=0.25, learning_rate='invscaling'), value=0.03162277660168379, device='cpu', warm_start=False, verbose=0, validation_every=10, learning_loss=SquareLearningLoss(), enable_logging=False, weight_name=None, learning_penalty=NoLearningPenalty(), exc=True)"]}, "execution_count": 28, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.fit(X, y)"]}, {"cell_type": "code", "execution_count": 28, "id": "8f34bf7a", "metadata": {}, "outputs": [{"data": {"text/plain": ["[0.00040441833, 0.00037421435, 0.00049950054, 0.00042527347, 0.00031072882]"]}, "execution_count": 29, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.train_losses_[-5:]"]}, {"cell_type": "code", "execution_count": 29, "id": "ba57e16f", "metadata": {}, "outputs": 
[{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["pandas.DataFrame({'loss': train_session.train_losses_}).plot();"]}, {"cell_type": "code", "execution_count": 30, "id": "5da877fd", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'init': ,\n", " 'init_1': ,\n", " 'init_b10': }"]}, "execution_count": 31, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.trained_coef_"]}, {"cell_type": "code", "execution_count": 31, "id": "fffb703d", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'init': array([[-0.35357383, 0.6850407 ]], dtype=float32),\n", " 'init_1': array([[-1.916494 , 2.8799832]], dtype=float32),\n", " 'init_b10': array([-1.0036615], dtype=float32)}"]}, "execution_count": 32, "metadata": {}, "output_type": "execute_result"}], "source": ["{k: v.numpy() for k, v in train_session.trained_coef_.items()}"]}, {"cell_type": "markdown", "id": "8caa3421", "metadata": {}, "source": ["Not the same weights? What about the prediction?"]}, {"cell_type": "code", "execution_count": 32, "id": "896d9bc4", "metadata": {}, "outputs": [], "source": ["trained_onx = train_session.get_trained_onnx()"]}, {"cell_type": "code", "execution_count": 33, "id": "360af856", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "init: name='init' type=dtype('float32') shape=(0,) -- array([-0.35357383, 0.6850407 ], dtype=float32)\n", "init: name='init_1' type=dtype('float32') shape=(0,) -- array([-1.916494 , 2.8799832], dtype=float32)\n", "init: name='init_2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='init_3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "init: name='init_5' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='init_b10' type=dtype('float32') shape=(0,) -- array([-1.0036615], dtype=float32)\n", "init: 
name='init_b11' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "Mul(x, x) -> out_mul_0\n", " Mul(out_mul_0, init) -> out_mul_0_1\n", "Mul(x, init_1) -> out_mul_0_2\n", " Add(out_mul_0_2, out_mul_0_1) -> out_add_0\n", " Slice(out_add_0, init_2, init_3, init_2) -> out_sli_0\n", " Squeeze(out_sli_0, init_2) -> out_squ_0\n", " Slice(out_add_0, init_5, init_2, init_2) -> out_sli_0_1\n", " Squeeze(out_sli_0_1, init_2) -> out_squ_0_1\n", " Add(out_squ_0_1, out_squ_0) -> out_add_0_1\n", " Add(out_add_0_1, init_b10) -> out_add_0_2\n", " Reshape(out_add_0_2, init_b11) -> y\n", "output: name='y' type=dtype('float32') shape=(0, 1)\n"]}], "source": ["print(onnx_simple_text_plot(trained_onx))"]}, {"cell_type": "code", "execution_count": 34, "id": "0aee38d3", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[-0.6123954],\n", " [-1.303561 ],\n", " [-2.0257921],\n", " [-1.2778704],\n", " [-0.9708453]], dtype=float32)"]}, "execution_count": 35, "metadata": {}, "output_type": "execute_result"}], "source": ["oinf = OnnxInference(trained_onx)\n", "oinf.run({'x': X})['y'][:5]"]}, {"cell_type": "code", "execution_count": 35, "id": "dc82cbb8", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[-0.58675164],\n", " [-1.3148587 ],\n", " [-2.0666485 ],\n", " [-1.272753 ],\n", " [-0.95404863]], dtype=float32)"]}, "execution_count": 36, "metadata": {}, "output_type": "execute_result"}], "source": ["y[:5]"]}, {"cell_type": "markdown", "id": "4a2ac16b", "metadata": {}, "source": ["It works."]}, {"cell_type": "markdown", "id": "854502c0", "metadata": {}, "source": ["## MLPRegressor"]}, {"cell_type": "code", "execution_count": 36, "id": "c9c14cfc", "metadata": {}, "outputs": [], "source": ["import warnings\n", "import time\n", "import numpy\n", "import matplotlib.pyplot as plt\n", "from pandas import DataFrame\n", "from onnxruntime import get_device\n", "from sklearn.datasets import make_regression\n", "from sklearn.model_selection import 
train_test_split\n", "from sklearn.neural_network import MLPRegressor\n", "from skl2onnx import to_onnx\n", "\n", "\n", "X, y = make_regression(1000, n_features=100, bias=2)\n", "X = X.astype(numpy.float32)\n", "y = y.astype(numpy.float32)\n", "X_train, X_test, y_train, y_test = train_test_split(X, y)"]}, {"cell_type": "code", "execution_count": 37, "id": "c36d7f35", "metadata": {}, "outputs": [], "source": ["batch_size = 15\n", "max_iter = 100\n", "\n", "nn = MLPRegressor(hidden_layer_sizes=(50, 10), max_iter=max_iter,\n", "                  solver='sgd', learning_rate_init=5e-5,\n", "                  n_iter_no_change=max_iter * 3, batch_size=batch_size,\n", "                  learning_rate=\"invscaling\",\n", "                  # default values\n", "                  momentum=0.9, nesterovs_momentum=True, power_t=0.5)\n", "\n", "with warnings.catch_warnings():\n", "    warnings.simplefilter('ignore')\n", "    nn.fit(X_train, y_train)"]}, {"cell_type": "markdown", "id": "4d5a10f9", "metadata": {}, "source": ["The trained model is converted to ONNX, then its weights are renamed with `onnx_rename_weights`."]}, {"cell_type": "code", "execution_count": 38, "id": "2e07fa9e", "metadata": {}, "outputs": [], "source": ["from onnxcustom.utils.onnx_helper import onnx_rename_weights\n", "onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)\n", "onx = onnx_rename_weights(onx)"]}, {"cell_type": "code", "execution_count": 39, "id": "49050eb7", "metadata": {}, "outputs": [], "source": ["train_session = OrtGradientForwardBackwardOptimizer(\n", "    onx, device='cpu', learning_rate=5e-5,\n", "    warm_start=False, max_iter=max_iter, batch_size=batch_size)"]}, {"cell_type": "code", "execution_count": 40, "id": "d658b104", "metadata": {}, "outputs": [{"data": {"text/plain": ["OrtGradientForwardBackwardOptimizer(model_onnx='ir_version...', weights_to_train=\"['I0_coeff...\", loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=15, learning_rate=LearningRateSGD(eta0=5e-05, alpha=0.0001, power_t=0.25, learning_rate='invscaling'), value=1.5811388300841898e-05, device='cpu', warm_start=False, verbose=0, 
validation_every=10, learning_loss=SquareLearningLoss(), enable_logging=False, weight_name=None, learning_penalty=NoLearningPenalty(), exc=True)"]}, "execution_count": 41, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 41, "id": "2c2e7ccc", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["pandas.DataFrame(dict(skl_loss=nn.loss_curve_, ort_loss=train_session.train_losses_)).plot();"]}, {"cell_type": "code", "execution_count": 42, "id": "018f9ae3", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["C:\\Python395_x64\\lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:692: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (100) reached and the optimization hasn't converged yet.\n", " warnings.warn(\n"]}, {"name": "stdout", "output_type": "stream", "text": ["1.98 s \u00b1 0 ns per loop (mean \u00b1 std. dev. of 1 run, 1 loop each)\n"]}], "source": ["%timeit -n 1 -r 1 nn.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 43, "id": "1a956825", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["1.88 s \u00b1 0 ns per loop (mean \u00b1 std. dev. of 1 run, 1 loop each)\n"]}], "source": ["%timeit -n 1 -r 1 train_session.fit(X_train, y_train)"]}, {"cell_type": "markdown", "id": "bb7d75d9", "metadata": {}, "source": ["## Not exactly the same: Nesterov?"]}, {"cell_type": "code", "execution_count": 44, "id": "534dfc61", "metadata": {}, "outputs": [], "source": ["from onnxcustom.training.sgd_learning_rate import LearningRateSGDNesterov\n", "\n", "train_session2 = OrtGradientForwardBackwardOptimizer(\n", " onx, device='cpu', warm_start=False, max_iter=max_iter, batch_size=batch_size,\n", " learning_rate=LearningRateSGDNesterov(1e-5, nesterov=True, momentum=0.9))"]}, {"cell_type": "code", "execution_count": 45, "id": "6e209570", "metadata": {}, "outputs": [{"data": {"text/plain": ["OrtGradientForwardBackwardOptimizer(model_onnx='ir_version...', weights_to_train=\"['I0_coeff...\", loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=15, learning_rate=LearningRateSGDNesterov(eta0=1e-05, alpha=0.0001, power_t=0.25, 
learning_rate='invscaling', momentum=0.9, nesterov=True), value=3.162277660168379e-06, device='cpu', warm_start=False, verbose=0, validation_every=10, learning_loss=SquareLearningLoss(), enable_logging=False, weight_name=None, learning_penalty=NoLearningPenalty(), exc=True)"]}, "execution_count": 46, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session2.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 46, "id": "788fee10", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["pandas.DataFrame(dict(skl_loss=nn.loss_curve_, \n", " ort_loss=train_session.train_losses_,\n", " ort_loss2=train_session2.train_losses_)).plot();"]}, {"cell_type": "code", "execution_count": 47, "id": "331120b3", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["2.26 s \u00b1 0 ns per loop (mean \u00b1 std. dev. of 1 run, 1 loop each)\n"]}], "source": ["%timeit -n 1 -r 1 train_session2.fit(X_train, y_train)"]}, {"cell_type": "markdown", "id": "0bdaba6b", "metadata": {}, "source": ["## Profiling"]}, {"cell_type": "code", "execution_count": 48, "id": "1505d30b", "metadata": {"scrolled": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": [" -- 1 1 -- 0.00001 3.78074 -- :18: ()\n", " fit -- 1 1 -- 0.00181 3.78073 -- onnxcustom/onnxcustom/training/optimizers_partial.py:263:fit (fit)\n", " __init__ -- 1 1 -- 0.00002 0.00003 -- onnxcustom/onnxcustom/training/data_loader.py:26:__init__ (__init__)\n", " get_ort_device -- 1 1 -- 0.00000 0.00000 -- onnxruntime_helper.py:55:get_ort_device (get_ort_device)\n", " numpy_to_ort_value -- 2 2 -- 0.00000 0.00001 -- onnxruntime_helper.py:120:numpy_to_ort_value (numpy_to_ort_value) +++\n", " needs_grad -- 3 3 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/optimizers_partial.py:99:needs_grad (needs_grad)\n", " needs_grad -- 3 3 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:299:needs_grad (needs_grad)\n", " get_full_state -- 101 101 -- 0.00020 0.00093 -- onnxcustom/onnxcustom/training/optimizers_partial.py:147:get_full_state (get_full_state) +++\n", " set_state -- 4 4 -- 0.00008 0.00026 -- onnxcustom/onnxcustom/training/optimizers_partial.py:196:set_state (set_state)\n", " _get_att_state -- 4 4 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/optimizers_partial.py:139:_get_att_state (_get_att_state) +++\n", " numpy_to_ort_value -- 24 24 -- 0.00002 0.00011 
-- onnxruntime_helper.py:120:numpy_to_ort_value (numpy_to_ort_value) +++\n", " -- 12 12 -- 0.00002 0.00002 -- ~:0: ()\n", " -- 56 56 -- 0.00001 0.00001 -- ~:0: () +++\n", " -- 24 24 -- 0.00000 0.00000 -- ~:0: () +++\n", " -- 1 1 -- 0.00001 0.00095 -- onnxcustom/onnxcustom/training/optimizers_partial.py:311: ()\n", " get_initializer -- 7 7 -- 0.00004 0.00094 -- onnxcustom/onnxcustom/training/ortgradient.py:269:get_initializer (get_initializer) +++\n", " -- 1 1 -- 0.00001 0.00083 -- onnxcustom/onnxcustom/training/optimizers_partial.py:315: ()\n", " get_initializer -- 7 7 -- 0.00004 0.00082 -- onnxcustom/onnxcustom/training/ortgradient.py:269:get_initializer (get_initializer) +++\n", " _iteration -- 100 100 -- 0.41903 3.74610 -- onnxcustom/onnxcustom/training/optimizers_partial.py:397:_iteration (_iteration)\n", " iter_ortvalue -- 6800 6800 -- 0.02838 0.14761 -- onnxcustom/onnxcustom/training/data_loader.py:139:iter_ortvalue (iter_ortvalue)\n", " _next_iter -- 6700 6700 -- 0.00946 0.07207 -- onnxcustom/onnxcustom/training/data_loader.py:93:_next_iter (_next_iter)\n", " -- 6700 6700 -- 0.00245 0.00423 -- ~:0: () +++\n", " -- 6700 6700 -- 0.05838 0.05838 -- ~:0: ()\n", " numpy_to_ort_value -- 13400 13400 -- 0.00658 0.03860 -- onnxruntime_helper.py:120:numpy_to_ort_value (numpy_to_ort_value) +++\n", " -- 6900 6900 -- 0.00467 0.00855 -- ~:0: () +++\n", " forward -- 6700 6700 -- 0.31685 0.44643 -- onnxcustom/onnxcustom/training/ortgradient.py:623:forward (forward)\n", " input_to_ort -- 6700 6700 -- 0.08002 0.11492 -- onnxcustom/onnxcustom/training/ortgradient.py:552:input_to_ort (input_to_ort) +++\n", " save_for_backward -- 6700 6700 -- 0.01032 0.01032 -- onnxcustom/onnxcustom/training/ortgradient.py:604:save_for_backward (save_for_backward)\n", " -- 6700 6700 -- 0.00434 0.00434 -- ~:0: () +++\n", " backward -- 6700 6700 -- 0.43012 0.48957 -- onnxcustom/onnxcustom/training/ortgradient.py:702:backward (backward)\n", " input_to_ort -- 6700 6700 -- 0.04148 0.05262 -- 
onnxcustom/onnxcustom/training/ortgradient.py:552:input_to_ort (input_to_ort) +++\n", " saved_tensors -- 6700 6700 -- 0.00207 0.00207 -- onnxcustom/onnxcustom/training/ortgradient.py:613:saved_tensors (saved_tensors)\n", " -- 6700 6700 -- 0.00476 0.00476 -- ~:0: ()\n", " loss_gradient -- 6700 6700 -- 0.05841 0.26967 -- onnxcustom/onnxcustom/training/sgd_learning_loss.py:53:loss_gradient (loss_gradient)\n", " clear_binding_inputs -- 6700 6700 -- 0.00545 0.01270 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:130:clear_binding_inputs (clear_binding_inputs)\n", " _cache_in_clear -- 6700 6700 -- 0.00568 0.00725 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:119:_cache_in_clear (_cache_in_clear)\n", " -- 6700 6700 -- 0.00157 0.00157 -- ~:0: () +++\n", " _bind_input_ortvalue -- 13400 13400 -- 0.02070 0.07545 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:159:_bind_input_ortvalue (_bind_input_ortvalue) +++\n", " _call_iobinding -- 6700 6700 -- 0.11997 0.11997 -- onnxcustom/onnxcustom/training/sgd_learning_loss.py:50:_call_iobinding (_call_iobinding)\n", " -- 13400 13400 -- 0.00315 0.00315 -- ~:0: () +++\n", " penalty_loss -- 6700 6700 -- 0.00112 0.00112 -- onnxcustom/onnxcustom/training/sgd_learning_penalty.py:84:penalty_loss (penalty_loss)\n", " update_weights -- 40200 40200 -- 0.00651 0.00651 -- onnxcustom/onnxcustom/training/sgd_learning_penalty.py:95:update_weights (update_weights)\n", " update_weights -- 40200 40200 -- 0.40487 1.94238 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:345:update_weights (update_weights)\n", " _bind_input_ortvalue -- 201000 201000 -- 0.19630 0.51693 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:159:_bind_input_ortvalue (_bind_input_ortvalue) +++\n", " _bind_output_ortvalue -- 80400 80400 -- 0.07458 0.18952 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:202:_bind_output_ortvalue (_bind_output_ortvalue)\n", " _bio_cache -- 80400 80400 -- 0.04417 0.05406 -- 
onnxcustom/onnxcustom/training/_base_onnx_function.py:138:_bio_cache (_bio_cache) +++\n", " _bio_ptr -- 80400 80400 -- 0.05222 0.05222 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:155:_bio_ptr (_bio_ptr) +++\n", " _bio_do_bind_out -- 12 12 -- 0.00003 0.00003 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:198:_bio_do_bind_out (_bio_do_bind_out)\n", " -- 80400 80400 -- 0.00863 0.00863 -- ~:0: () +++\n", " _call_iobinding -- 40200 40200 -- 0.63987 0.63987 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:28:_call_iobinding (_call_iobinding)\n", " value -- 40200 40200 -- 0.00953 0.00953 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:160:value (value) +++\n", " -- 80400 80400 -- 0.16512 0.16512 -- ~:0: () +++\n", " -- 80400 80400 -- 0.01655 0.01655 -- ~:0: () +++\n", " -- 100 100 -- 0.00026 0.00426 -- ~:0: ()\n", " _mean -- 100 100 -- 0.00163 0.00400 -- site-packages/numpy/core/_methods.py:162:_mean (_mean)\n", " _count_reduce_items -- 100 100 -- 0.00097 0.00107 -- site-packages/numpy/core/_methods.py:66:_count_reduce_items (_count_reduce_items)\n", " -- 200 200 -- 0.00010 0.00010 -- ~:0: ()\n", " -- 100 100 -- 0.00004 0.00004 -- ~:0: ()\n", " -- 100 100 -- 0.00109 0.00109 -- ~:0: ()\n", " -- 100 100 -- 0.00006 0.00006 -- ~:0: () +++\n", " -- 100 100 -- 0.00004 0.00004 -- ~:0: () +++\n", " -- 200 200 -- 0.00007 0.00007 -- ~:0: ()\n", " -- 100 100 -- 0.00358 0.00358 -- ~:0: ()\n", " -- 6700 6700 -- 0.00169 0.00169 -- ~:0: () +++\n", " -- 40300 40300 -- 0.01424 0.01424 -- ~:0: () +++\n", " _create_training_session -- 1 1 -- 0.00001 0.02824 -- onnxcustom/onnxcustom/training/optimizers_partial.py:626:_create_training_session (_create_training_session)\n", " __init__ -- 1 1 -- 0.00008 0.02820 -- onnxcustom/onnxcustom/training/ortgradient.py:54:__init__ (__init__)\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:91: ()\n", " -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:94: ()\n", 
" -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:113: ()\n", " _init_next -- 1 1 -- 0.00010 0.02809 -- onnxcustom/onnxcustom/training/ortgradient.py:163:_init_next (_init_next)\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:173: ()\n", " -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:175: ()\n", " -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:178: ()\n", " _create_onnx_graphs -- 1 1 -- 0.00662 0.02797 -- onnxcustom/onnxcustom/training/ortgradient.py:287:_create_onnx_graphs (_create_onnx_graphs)\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:396: ()\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:397: ()\n", " -- 1 1 -- 0.00001 0.00002 -- onnxcustom/onnxcustom/training/ortgradient.py:399: ()\n", " _provider_name_to_device_type -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:260:_provider_name_to_device_type (_provider_name_to_device_type) +++\n", " -- 1 1 -- 0.00002 0.00002 -- onnxcustom/onnxcustom/training/ortgradient.py:404: ()\n", " _provider_name_to_device_type -- 7 7 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:260:_provider_name_to_device_type (_provider_name_to_device_type) +++\n", " -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:410: ()\n", " _provider_name_to_device_type -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:260:_provider_name_to_device_type (_provider_name_to_device_type) +++\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:479: ()\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:480: ()\n", " get_inputs -- 1 1 -- 0.00000 0.00000 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:111:get_inputs (get_inputs)\n", " get_outputs -- 1 1 -- 0.00000 0.00000 
-- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:115:get_outputs (get_outputs)\n", " __init__ -- 2 2 -- 0.00004 0.02063 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:283:__init__ (__init__)\n", " get -- 2 2 -- 0.00001 0.00004 -- C:/Python395_x64/lib/_collections_abc.py:759:get (get)\n", " __getitem__ -- 2 2 -- 0.00001 0.00003 -- C:/Python395_x64/lib/os.py:674:__getitem__ (__getitem__)\n", " encodekey -- 2 2 -- 0.00001 0.00002 -- C:/Python395_x64/lib/os.py:746:encodekey (encodekey)\n", " check_str -- 2 2 -- 0.00000 0.00000 -- C:/Python395_x64/lib/os.py:740:check_str (check_str)\n", " __init__ -- 2 2 -- 0.00000 0.00000 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:101:__init__ (__init__)\n", " _create_inference_session -- 2 2 -- 0.02045 0.02055 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:346:_create_inference_session (_create_inference_session)\n", " check_and_nor...rovider_args -- 2 2 -- 0.00004 0.00008 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:25:check_and_normalize_provider_args (check_and_normalize_provider_args)\n", " set_provider_options -- 2 2 -- 0.00001 0.00001 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:53:set_provider_options (set_provider_options)\n", " -- 2 2 -- 0.00000 0.00000 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:62: ()\n", " -- 2 2 -- 0.00000 0.00000 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:75: ()\n", " -- 2 2 -- 0.00000 0.00000 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:78: ()\n", " load_model -- 2 2 -- 0.00001 0.00049 -- site-packages/onnx/__init__.py:107:load_model 
(load_model)\n", " _load_bytes -- 2 2 -- 0.00002 0.00003 -- site-packages/onnx/__init__.py:30:_load_bytes (_load_bytes)\n", " inner -- 4 4 -- 0.00000 0.00000 -- C:/Python395_x64/lib/typing.py:262:inner (inner) +++\n", " cast -- 4 4 -- 0.00000 0.00000 -- C:/Python395_x64/lib/typing.py:1333:cast (cast) +++\n", " _get_file_path -- 2 2 -- 0.00000 0.00000 -- site-packages/onnx/__init__.py:50:_get_file_path (_get_file_path)\n", " load_model_from_string -- 2 2 -- 0.00001 0.00045 -- site-packages/onnx/__init__.py:147:load_model_from_string (load_model_from_string)\n", " _deserialize -- 2 2 -- 0.00001 0.00044 -- site-packages/onnx/__init__.py:81:_deserialize (_deserialize)\n", " inner -- 2 2 -- 0.00000 0.00000 -- C:/Python395_x64/lib/typing.py:262:inner (inner) +++\n", " cast -- 2 2 -- 0.00000 0.00000 -- C:/Python395_x64/lib/typing.py:1333:cast (cast) +++\n", " -- 2 2 -- 0.00042 0.00042 -- ~:0: ()\n", " -- 16 16 -- 0.00000 0.00000 -- ~:0: () +++\n", " -- 1 1 -- 0.00014 0.00014 -- ~:0: ()\n", " new_instance -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:211:new_instance (new_instance)\n", " __init__ -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:501:__init__ (__init__)\n", " device_to_providers -- 1 1 -- 0.00003 0.00003 -- onnxruntime_helper.py:133:device_to_providers (device_to_providers)\n", " value -- 100 100 -- 0.00003 0.00003 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:160:value (value) +++\n", " init_learning_rate -- 1 1 -- 0.00000 0.00001 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:307:init_learning_rate (init_learning_rate)\n", " init_learning_rate -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:176:init_learning_rate (init_learning_rate)\n", " update_learning_rate -- 100 100 -- 0.00015 0.00098 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:314:update_learning_rate (update_learning_rate)\n", " update_learning_rate -- 100 100 -- 0.00084 0.00084 -- 
onnxcustom/onnxcustom/training/sgd_learning_rate.py:194:update_learning_rate (update_learning_rate)\n", " proto_type_to_dtype -- 6 6 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/utils/onnx_helper.py:53:proto_type_to_dtype (proto_type_to_dtype)\n", " -- 107 107 -- 0.00003 0.00003 -- ~:0: () +++\n", " -- 108 108 -- 0.00002 0.00002 -- ~:0: () +++\n", " -- 6 6 -- 0.00040 0.00040 -- ~:0: ()\n", "inner -- 6 6 -- 0.00001 0.00001 -- C:/Python395_x64/lib/typing.py:262:inner (inner)\n", "cast -- 6 6 -- 0.00000 0.00000 -- C:/Python395_x64/lib/typing.py:1333:cast (cast)\n", "_bio_cache -- 294800 294800 -- 0.18126 0.22052 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:138:_bio_cache (_bio_cache)\n", " -- 294800 294800 -- 0.03926 0.03926 -- ~:0: () +++\n", "_bio_ptr -- 294800 294800 -- 0.20762 0.20762 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:155:_bio_ptr (_bio_ptr)\n", "_bind_input_ortvalue -- 214400 214400 -- 0.21699 0.59239 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:159:_bind_input_ortvalue (_bind_input_ortvalue)\n", " _bio_cache -- 214400 214400 -- 0.13709 0.16646 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:138:_bio_cache (_bio_cache) +++\n", " _bio_do_bind_in -- 14000 14000 -- 0.03012 0.03012 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:151:_bio_do_bind_in (_bio_do_bind_in)\n", " _bio_ptr -- 214400 214400 -- 0.15540 0.15540 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:155:_bio_ptr (_bio_ptr) +++\n", " -- 214400 214400 -- 0.02341 0.02341 -- ~:0: () +++\n", "_get_att_state -- 205 205 -- 0.00007 0.00007 -- onnxcustom/onnxcustom/training/optimizers_partial.py:139:_get_att_state (_get_att_state)\n", "get_full_state -- 101 301 -- 0.00049 0.00093 -- onnxcustom/onnxcustom/training/optimizers_partial.py:147:get_full_state (get_full_state)\n", " _get_att_state -- 201 201 -- 0.00007 0.00007 -- onnxcustom/onnxcustom/training/optimizers_partial.py:139:_get_att_state (_get_att_state) +++\n", " -- 100 100 -- 
0.00021 0.00072 -- onnxcustom/onnxcustom/training/optimizers_partial.py:152: ()\n", " get_full_state -- 200 200 -- 0.00030 0.00050 -- onnxcustom/onnxcustom/training/optimizers_partial.py:147:get_full_state (get_full_state) +++\n", " -- 201 201 -- 0.00004 0.00004 -- ~:0: () +++\n", " -- 201 201 -- 0.00005 0.00005 -- ~:0: () +++\n", " -- 301 301 -- 0.00007 0.00007 -- ~:0: () +++\n", "_provider_name_to_device_type -- 9 9 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:260:_provider_name_to_device_type (_provider_name_to_device_type)\n", "get_initializer -- 14 14 -- 0.00008 0.00175 -- onnxcustom/onnxcustom/training/ortgradient.py:269:get_initializer (get_initializer)\n", " to_array -- 12 12 -- 0.00009 0.00168 -- site-packages/onnx/numpy_helper.py:21:to_array (to_array)\n", " uses_external_data -- 12 12 -- 0.00001 0.00001 -- site-packages/onnx/external_data_helper.py:224:uses_external_data (uses_external_data)\n", " -- 12 12 -- 0.00000 0.00000 -- ~:0: () +++\n", " -- 12 12 -- 0.00006 0.00006 -- ~:0: () +++\n", " -- 12 12 -- 0.00002 0.00002 -- ~:0: () +++\n", " -- 12 12 -- 0.00148 0.00148 -- ~:0: ()\n", " -- 12 12 -- 0.00001 0.00001 -- ~:0: () +++\n", " -- 24 24 -- 0.00001 0.00001 -- ~:0: () +++\n", "input_to_ort -- 13400 13400 -- 0.12150 0.16754 -- onnxcustom/onnxcustom/training/ortgradient.py:552:input_to_ort (input_to_ort)\n", " -- 13400 13400 -- 0.01681 0.03690 -- ~:0: () +++\n", " -- 13400 13400 -- 0.00712 0.00712 -- ~:0: () +++\n", " -- 13400 13400 -- 0.00202 0.00202 -- ~:0: () +++\n", "value -- 40300 40300 -- 0.00955 0.00955 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:160:value (value)\n", "numpy_to_ort_value -- 13426 13426 -- 0.00661 0.03872 -- onnxruntime_helper.py:120:numpy_to_ort_value (numpy_to_ort_value)\n", " -- 13426 13426 -- 0.03211 0.03211 -- ~:0: () +++\n", " -- 18 18 -- 0.00014 0.00014 -- ~:0: ()\n", " -- 13575 13575 -- 0.00608 0.00608 -- ~:0: ()\n", " -- 94120 94120 -- 0.01981 0.01981 -- ~:0: ()\n", " -- 362251 362251 -- 
0.04476 0.04477 -- ~:0: ()\n", " __instancecheck__ -- 4 4 -- 0.00001 0.00001 -- C:/Python395_x64/lib/abc.py:96:__instancecheck__ (__instancecheck__)\n", " -- 67437 67437 -- 0.02341 0.02908 -- ~:0: ()\n", " __len__ -- 13600 13600 -- 0.00567 0.00567 -- onnxcustom/onnxcustom/training/data_loader.py:89:__len__ (__len__)\n", " -- 14 14 -- 0.00002 0.00002 -- ~:0: ()\n", " -- 213 213 -- 0.00005 0.00005 -- ~:0: ()\n", " -- 93826 93826 -- 0.19723 0.19723 -- ~:0: ()\n", " -- 301501 301501 -- 0.04083 0.04083 -- ~:0: ()\n", " -- 36 36 -- 0.00001 0.00001 -- ~:0: ()\n", " -- 13404 13404 -- 0.01681 0.03690 -- ~:0: ()\n", " -- 53600 53600 -- 0.01461 0.02009 -- onnxcustom/onnxcustom/training/ortgradient.py:572: ()\n"]}, {"name": "stdout", "output_type": "stream", "text": [" -- 53600 53600 -- 0.00548 0.00548 -- ~:0: () +++\n"]}], "source": ["from pyquickhelper.pycode.profiling import profile, profile2graph\n", "\n", "\n", "def clean_name(text):\n", "    # Shortens a file path so that it starts at the first known package name.\n", "    for name in ('onnxruntime', 'sklearn', 'onnxcustom', 'site-packages'):\n", "        pos = text.find(name)\n", "        if pos >= 0:\n", "            return text[pos:]\n", "    return text\n", "\n", "\n", "ps = profile(lambda: train_session2.fit(X, y))[0]\n", "root, nodes = profile2graph(ps, clean_text=clean_name)\n", "text = root.to_text()\n", "print(text)"]}, {"cell_type": "markdown", "id": "ac947d2d", "metadata": {}, "source": ["The excerpt below shows that about half of the time spent in `_iteration` (1.94 s out of 3.75 s) goes to `update_weights`, more than `forward` (0.45 s) and `backward` (0.49 s) combined.\n", "\n", "```\n", " _iteration -- 100 100 -- 0.41903 3.74610 -- \n", " iter_ortvalue -- 6800 6800 -- 0.02838 0.14761 -- \n", " _next_iter -- 6700 6700 -- 0.00946 0.07207 -- \n", " -- 6700 6700 -- 0.00245 0.00423 -- \n", " -- 6700 6700 -- 0.05838 0.05838 -- \n", " numpy_to_ort_value -- 13400 13400 -- 0.00658 0.03860 -- \n", " -- 6900 6900 -- 0.00467 0.00855 -- \n", " forward -- 6700 6700 -- 0.31685 0.44643 -- \n", " input_to_ort -- 6700 6700 -- 0.08002 0.11492 -- \n", " save_for_backward -- 6700 
6700 -- 0.01032 0.01032 -- \n", " -- 6700 6700 -- 0.00434 0.00434 -- \n", " backward -- 6700 6700 -- 0.43012 0.48957 -- \n", " input_to_ort -- 6700 6700 -- 0.04148 0.05262 -- \n", " saved_tensors -- 6700 6700 -- 0.00207 0.00207 -- \n", " -- 6700 6700 -- 0.00476 0.00476 -- \n", " loss_gradient -- 6700 6700 -- 0.05841 0.26967 -- \n", " clear_binding_inputs -- 6700 6700 -- 0.00545 0.01270 -- \n", " _cache_in_clear -- 6700 6700 -- 0.00568 0.00725 -- \n", " -- 6700 6700 -- 0.00157 0.00157 -- \n", " _bind_input_ortvalue -- 13400 13400 -- 0.02070 0.07545 -- \n", " _call_iobinding -- 6700 6700 -- 0.11997 0.11997 -- \n", " -- 13400 13400 -- 0.00315 0.00315 -- \n", " penalty_loss -- 6700 6700 -- 0.00112 0.00112 -- \n", " update_weights -- 40200 40200 -- 0.00651 0.00651 -- \n", " update_weights -- 40200 40200 -- 0.40487 1.94238 -- \n", " _bind_input_ortvalue -- 201000 201000 -- 0.19630 0.51693 -- \n", " _bind_output_ortvalue -- 80400 80400 -- 0.07458 0.18952 -- \n", " _bio_cache -- 80400 80400 -- 0.04417 0.05406 -- \n", " _bio_ptr -- 80400 80400 -- 0.05222 0.05222 -- \n", " _bio_do_bind_out -- 12 12 -- 0.00003 0.00003 -- \n", " -- 80400 80400 -- 0.00863 0.00863 -- \n", " _call_iobinding -- 40200 40200 -- 0.63987 0.63987 -- \n", " value -- 40200 40200 -- 0.00953 0.00953 -- \n", " -- 80400 80400 -- 0.16512 0.16512 -- \n", " -- 80400 80400 -- 0.01655 0.01655 -- \n", " -- 100 100 -- 0.00026 0.00426 -- \n", " _mean -- 100 100 -- 0.00163 0.00400 -- \n", " _count_reduce_items -- 100 100 -- 0.00097 0.00107 -- \n", " -- 200 200 -- 0.00010 0.00010 -- \n", " -- 100 100 -- 0.00004 0.00004 -- \n", " -- 100 100 -- 0.00109 0.00109 -- \n", " -- 100 100 -- 0.00006 0.00006 -- \n", " -- 100 100 -- 0.00004 0.00004 -- \n", " -- 200 200 -- 0.00007 0.00007 -- \n", " -- 100 100 -- 0.00358 0.00358 -- \n", " -- 6700 6700 -- 0.00169 0.00169 -- \n", " -- 40300 40300 -- 0.01424 0.01424 -- \n", " _create_training_session -- 1 1 -- 0.00001 0.02824 -- \n", " __init__ -- 1 1 -- 0.00008 0.02820 -- \n", " -- 1 
1 -- 0.00001 0.00001 -- \n", " -- 1 1 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00000 0.00000 -- \n", " _init_next -- 1 1 -- 0.00010 0.02809 -- \n", " -- 1 1 -- 0.00001 0.00001 -- \n", " -- 1 1 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00000 0.00000 -- \n", " _create_onnx_graphs -- 1 1 -- 0.00662 0.02797 -- \n", " -- 1 1 -- 0.00001 0.00001 -- \n", " -- 1 1 -- 0.00001 0.00001 -- \n", " -- 1 1 -- 0.00001 0.00002 -- \n", " _provider_name_to_device_type -- 1 1 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00002 0.00002 -- \n", " _provider_name_to_device_type -- 7 7 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00000 0.00000 -- \n", " _provider_name_to_device_type -- 1 1 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00001 0.00001 -- \n", " -- 1 1 -- 0.00001 0.00001 -- \n", " get_inputs -- 1 1 -- 0.00000 0.00000 -- \n", " get_outputs -- 1 1 -- 0.00000 0.00000 -- \n", " __init__ -- 2 2 -- 0.00004 0.02063 -- \n", " get -- 2 2 -- 0.00001 0.00004 -- \n", " __getitem__ -- 2 2 -- 0.00001 0.00003 -- \n", " encodekey -- 2 2 -- 0.00001 0.00002 -- \n", " check_str -- 2 2 -- 0.00000 0.00000 -- \n", " __init__ -- 2 2 -- 0.00000 0.00000 -- \n", " _create_inference_session -- 2 2 -- 0.02045 0.02055 -- \n", " check_and_nor...rovider_args -- 2 2 -- 0.00004 0.00008 -- \n", " set_provider_options -- 2 2 -- 0.00001 0.00001 -- \n", " -- 2 2 -- 0.00000 0.00000 -- \n", " -- 2 2 -- 0.00000 0.00000 -- \n", " -- 2 2 -- 0.00000 0.00000 -- \n", " load_model -- 2 2 -- 0.00001 0.00049 -- \n", " _load_bytes -- 2 2 -- 0.00002 0.00003 -- \n", " inner -- 4 4 -- 0.00000 0.00000 -- \n", " cast -- 4 4 -- 0.00000 0.00000 -- \n", " _get_file_path -- 2 2 -- 0.00000 0.00000 -- \n", " load_model_from_string -- 2 2 -- 0.00001 0.00045 -- \n", " _deserialize -- 2 2 -- 0.00001 0.00044 -- \n", " inner -- 2 2 -- 0.00000 0.00000 -- \n", " cast -- 2 2 -- 0.00000 0.00000 -- \n", " -- 2 2 -- 0.00042 0.00042 -- \n", " -- 16 16 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00014 0.00014 -- \n", " new_instance -- 1 1 -- 0.00000 0.00000 -- \n", " 
__init__ -- 1 1 -- 0.00000 0.00000 -- \n", " device_to_providers -- 1 1 -- 0.00003 0.00003 -- \n", " value -- 100 100 -- 0.00003 0.00003 -- \n", "\n", "```"]}, {"cell_type": "code", "execution_count": 49, "id": "af74c9ce", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'model_onnx': 'mlp_onnx_ort\\\\GradFBOptimizer.model_onnx.onnx',\n", " 'learning_rate': {'axpyw_onnx_': 'mlp_onnx_ort\\\\LRateSGDNesterov.learning_rate.axpyw_onnx_.onnx'},\n", " 'learning_loss': {'loss_grad_onnx_': 'mlp_onnx_ort\\\\SquareLLoss.learning_loss.loss_grad_onnx_.onnx',\n", " 'loss_score_onnx_': 'mlp_onnx_ort\\\\SquareLLoss.learning_loss.loss_score_onnx_.onnx'},\n", " 'learning_penalty': {},\n", " 'zero_onnx_': 'mlp_onnx_ort\\\\GradFBOptimizer.zero_onnx_.onnx',\n", " 'train_function_': {'_trained_onnx': 'mlp_onnx_ort\\\\OrtGradientForwardBackwardFunction_1523278698000.train_function_._trained_onnx.onnx',\n", " '_optimized_pre_grad_model': 'mlp_onnx_ort\\\\OrtGradientForwardBackwardFunction_1523278698000.train_function_._optimized_pre_grad_model.onnx'}}"]}, "execution_count": 50, "metadata": {}, "output_type": "execute_result"}], "source": ["import os\n", "\n", "# Saves the ONNX graphs used by the optimizer into a folder.\n", "os.makedirs(\"mlp_onnx_ort\", exist_ok=True)\n", "train_session2.save_onnx_graph(\"mlp_onnx_ort\")"]}, {"cell_type": "markdown", "id": "1fce825b", "metadata": {}, "source": ["The weights are updated with the following ONNX graph:"]}, {"cell_type": "code", "execution_count": 50, "id": "bb84a7c6", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 51, "metadata": {}, "output_type": "execute_result"}], "source": ["%onnxview train_session2.learning_rate.axpyw_onnx_"]}, {"cell_type": "code", "execution_count": 51, "id": "85bb6294", "metadata": {}, "outputs": [], "source": []}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5"}}, "nbformat": 4, "nbformat_minor": 5}