{"cells": [{"cell_type": "markdown", "id": "61c4a48d", "metadata": {}, "source": ["# Stochastic Gradient Descent on a simple function\n", "\n", "[onnxruntime-training](https://github.com/microsoft/onnxruntime) is an extension of onnxruntime or, more precisely, the same library compiled with different settings. It provides a way to compute the gradient of a function defined by an ONNX graph."]}, {"cell_type": "code", "execution_count": 1, "id": "cb34ae6d", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 2, "metadata": {}, "output_type": "execute_result"}], "source": ["from jyquickhelper import add_notebook_menu\n", "add_notebook_menu()"]}, {"cell_type": "code", "execution_count": 2, "id": "ef1a1dfa", "metadata": {}, "outputs": [], "source": ["%load_ext mlprodict"]}, {"cell_type": "markdown", "id": "e39f1613", "metadata": {}, "source": ["## A simple problem\n", "\n", "Let's choose a simple regression problem defined by $z = -1 - 2x + 3y + \\frac{1}{2}x^2 -\\frac{1}{3} y^2 +\\epsilon$, which we try to approximate with a function $f(x,y) = a + bx + cy + dx^2 + ey^2$. The coefficients $a, b, c, d, e$ are obtained by solving an optimization problem with stochastic gradient descent."]}, {"cell_type": "code", "execution_count": 3, "id": "f50d8a3d", "metadata": {"scrolled": false}, "outputs": [{"data": {"text/plain": ["array([[-1. ],\n", " [-2.5 ],\n", " [ 1.6666701 ],\n", " [ 0.16667008],\n", " [ 1.6666799 ]], dtype=float32)"]}, "execution_count": 4, "metadata": {}, "output_type": "execute_result"}], "source": ["from typing import Any\n", "import numpy\n", "import mlprodict.npy.numpy_onnx_impl as npnx\n", "from mlprodict.npy import onnxnumpy_default, NDArray\n", "\n", "\n", "@onnxnumpy_default\n", "def fct(x: NDArray[(None, 2), numpy.float32]) -> NDArray[(None, 1), numpy.float32]:\n", "    # linear coefficients (b, c), quadratic coefficients (d, e) and bias a\n", "    coef_x = numpy.array([[-2, 3]], dtype=numpy.float32)\n", "    coef_x2 = numpy.array([[0.5, -0.33333]], dtype=numpy.float32)\n", "    bias = numpy.array([-1], dtype=numpy.float32)\n", "    poly = x * coef_x + x * x * coef_x2\n", "    # sum the contributions of both columns and add the bias\n", "    y = poly[:, 0] + poly[:, 1] + bias\n", "    return y.reshape((-1, 1))\n", "\n", "\n", "x = numpy.array([[0, 0], [1, 0], [0, 1], [1, 1], [2, 2]], dtype=numpy.float32)\n", "fct(x)"]}, {"cell_type": "code", "execution_count": 4, "id": "0e3735c0", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 5, "metadata": {}, "output_type": "execute_result"}], "source": ["%onnxview fct.to_onnx()"]}, {"cell_type": "code", "execution_count": 5, "id": "f5c35b29", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "init: name='init' type=dtype('float32') shape=(0,) -- array([ 0.5 , -0.33333], dtype=float32)\n", "init: name='init_1' type=dtype('float32') shape=(0,) -- array([-2., 3.], dtype=float32)\n", "init: name='init_2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='init_3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "init: name='init_5' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='init_b10' type=dtype('float32') shape=(0,) -- array([-1.], dtype=float32)\n", "init: name='init_b11' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "Mul(x, x) -> out_mul_0\n", " Mul(out_mul_0, init) -> out_mul_0_1\n", "Mul(x, init_1) -> out_mul_0_2\n", " Add(out_mul_0_2, out_mul_0_1) -> out_add_0\n", " Slice(out_add_0, init_2, init_3, init_2) -> out_sli_0\n", " Squeeze(out_sli_0, init_2) -> out_squ_0\n", " Slice(out_add_0, init_5, init_2, init_2) -> out_sli_0_1\n", " Squeeze(out_sli_0_1, init_2) -> out_squ_0_1\n", " Add(out_squ_0_1, out_squ_0) -> out_add_0_1\n", " Add(out_add_0_1, init_b10) -> out_add_0_2\n", " Reshape(out_add_0_2, init_b11) -> y\n", "output: name='y' type=dtype('float32') shape=(0, 1)\n"]}], "source": ["from mlprodict.plotting.text_plot import onnx_simple_text_plot\n", "print(onnx_simple_text_plot(fct.to_onnx()))"]}, {"cell_type": "markdown", "id": "b56814b7", "metadata": {}, "source": ["## Gradient: backpropagation\n", "\n", "Let's look at the gradient graph."]}, {"cell_type": "code", "execution_count": 6, "id": "695682f5", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["from onnxcustom.training.grad_helper import onnx_derivative, DerivativeOptions\n", "\n", "onx = fct.to_onnx()\n", "grad = onnx_derivative(onx)\n", "%onnxview grad"]}, {"cell_type": "code", "execution_count": 7, "id": "72d51122", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "opset: domain='com.microsoft.nchwc' version=1\n", "opset: domain='ai.onnx.ml' version=2\n", "opset: domain='com.ms.internal.nhwc' version=1\n", "opset: domain='ai.onnx.training' version=1\n", "opset: domain='ai.onnx.preview.training' version=1\n", "opset: domain='com.microsoft' version=1\n", "opset: domain='com.microsoft.experimental' version=1\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "input: name='init' type=dtype('float32') shape=(1, 2)\n", "input: name='init_1' type=dtype('float32') shape=(1, 2)\n", "input: name='init_b10' type=dtype('float32') shape=(1,)\n", "input: name='y_grad' type=dtype('float32') shape=(0, 1)\n", "init: name='init_5' type=dtype('int64') shape=(0,)\n", "init: name='init_2' type=dtype('int64') shape=(0,)\n", "init: name='init_3' type=dtype('int64') shape=(0,)\n", "output: name='x_grad' type=dtype('float32') shape=(0, 2)\n", "output: name='init_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_1_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_b10_grad' type=dtype('float32') shape=(1,)\n"]}], "source": ["from mlprodict.plotting.text_plot import onnx_text_plot_io, onnx_simple_text_plot\n", "print(onnx_text_plot_io(grad))"]}, {"cell_type": "code", "execution_count": 8, "id": "753199ee", "metadata": {}, "outputs": [], "source": ["from mlprodict.onnx_tools.onnx_manipulations import onnx_rename_names\n", "renamed = onnx_rename_names(grad)"]}, {"cell_type": "code", "execution_count": 9, "id": "2d0b3a76", "metadata": {}, "outputs": [{"name": 
"stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "opset: domain='com.microsoft.nchwc' version=1\n", "opset: domain='ai.onnx.ml' version=2\n", "opset: domain='com.ms.internal.nhwc' version=1\n", "opset: domain='ai.onnx.training' version=1\n", "opset: domain='ai.onnx.preview.training' version=1\n", "opset: domain='com.microsoft' version=1\n", "opset: domain='com.microsoft.experimental' version=1\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "input: name='init' type=dtype('float32') shape=(1, 2)\n", "input: name='init_1' type=dtype('float32') shape=(1, 2)\n", "input: name='init_b10' type=dtype('float32') shape=(1,)\n", "input: name='y_grad' type=dtype('float32') shape=(0, 1)\n", "init: name='i0' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='i1' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='i2' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "Mul(x, x) -> r0\n", " Mul(r0, init) -> r1\n", " Shape(r1) -> r32\n", "Mul(x, init_1) -> r2\n", " Add(r2, r1) -> r3\n", " Slice(r3, i1, i2, i1) -> r4\n", " Squeeze(r4, i1) -> r5\n", " Shape(r5) -> r18\n", " Slice(r3, i0, i1, i1) -> r6\n", " Squeeze(r6, i1) -> r7\n", " Add(r7, r5) -> r8\n", " Add(r8, init_b10) -> r9\n", " Shape(r9) -> r10\n", " Reshape(y_grad, r10, allowzero=0) -> r11\n", "Shape(init_b10) -> r12\n", "Shape(r8) -> r13\n", " BroadcastGradientArgs(r13, r12) -> r14, r15\n", " ReduceSum(r11, r14, keepdims=1, noop_with_empty_axes=1) -> r16\n", " Reshape(r16, r13, allowzero=0) -> r17\n", "Shape(r7) -> r19\n", " BroadcastGradientArgs(r19, r18) -> r20, r21\n", " ReduceSum(r17, r21, keepdims=1, noop_with_empty_axes=1) -> r22\n", " Reshape(r22, r18, allowzero=0) -> r23\n", " Unsqueeze(r23, i1) -> r24\n", " Shape(r3) -> r25\n", " SliceGrad(r24, r25, i1, i2, i1) -> r26\n", " ReduceSum(r17, r20, keepdims=1, noop_with_empty_axes=1) -> r27\n", " Reshape(r27, r19, allowzero=0) -> r28\n", " Unsqueeze(r28, i1) -> r29\n", " 
SliceGrad(r29, r25, i0, i1, i1) -> r30\n", " Sum(r30, r26) -> r31\n", " Shape(r2) -> r33\n", " BroadcastGradientArgs(r33, r32) -> r34, r35\n", " ReduceSum(r31, r35, keepdims=1, noop_with_empty_axes=1) -> r36\n", " Reshape(r36, r32, allowzero=0) -> r37\n", " Mul(r37, init) -> r38\n", "Shape(init) -> r39\n", "Shape(r0) -> r40\n", " BroadcastGradientArgs(r40, r39) -> r41, r42\n", " ReduceSum(r38, r41, keepdims=1, noop_with_empty_axes=1) -> r43\n", " Reshape(r43, r40, allowzero=0) -> r44\n", " Mul(r44, x) -> r45\n", "ReduceSum(r31, r34, keepdims=1, noop_with_empty_axes=1) -> r46\n", " Reshape(r46, r33, allowzero=0) -> r47\n", " Mul(r47, init_1) -> r48\n", "Shape(init_1) -> r49\n", "Shape(x) -> r50\n", " BroadcastGradientArgs(r50, r49) -> r51, r52\n", " ReduceSum(r48, r51, keepdims=1, noop_with_empty_axes=1) -> r53\n", " Reshape(r53, r50, allowzero=0) -> r54\n", " Sum(r54, r45, r45) -> x_grad\n", "ReduceSum(r11, r15, keepdims=1, noop_with_empty_axes=1) -> r55\n", " Reshape(r55, r12, allowzero=0) -> init_b10_grad\n", "Mul(r37, r0) -> r56\n", " ReduceSum(r56, r42, keepdims=1, noop_with_empty_axes=1) -> r57\n", " Reshape(r57, r39, allowzero=0) -> init_grad\n", "Mul(r47, x) -> r58\n", " ReduceSum(r58, r52, keepdims=1, noop_with_empty_axes=1) -> r59\n", " Reshape(r59, r49, allowzero=0) -> init_1_grad\n", "output: name='x_grad' type=dtype('float32') shape=(0, 2)\n", "output: name='init_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_1_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_b10_grad' type=dtype('float32') shape=(1,)\n"]}], "source": ["print(onnx_simple_text_plot(renamed))"]}, {"cell_type": "code", "execution_count": 10, "id": "d0f09fbc", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'Add',\n", " 'BroadcastGradientArgs',\n", " 'Mul',\n", " 'ReduceSum',\n", " 'Reshape',\n", " 'Shape',\n", " 'Slice',\n", " 'SliceGrad',\n", " 'Squeeze',\n", " 'Sum',\n", " 'Unsqueeze'}"]}, "execution_count": 11, "metadata": {}, "output_type": 
"execute_result"}], "source": ["set(n.op_type for n in grad.graph.node)"]}, {"cell_type": "markdown", "id": "6e3558bf", "metadata": {}, "source": ["The resulting graph assumes the gradient for `y_grad` is known. That's the case for a layer in a neural network. In our case, this gradient should come from the loss. Let's add it to the graph."]}, {"cell_type": "markdown", "id": "701eead0", "metadata": {}, "source": ["## Add a square loss"]}, {"cell_type": "code", "execution_count": 11, "id": "4d885b0d", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 12, "metadata": {}, "output_type": "execute_result"}], "source": ["from onnxcustom.utils.orttraining_helper import add_loss_output\n", "onx_loss = add_loss_output(onx)\n", "\n", "%onnxview onx_loss"]}, {"cell_type": "code", "execution_count": 12, "id": "6b451748", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "input: name='label' type=dtype('float32') shape=(0, 1)\n", "init: name='init' type=dtype('float32') shape=(0,) -- array([ 0.5 , -0.33333], dtype=float32)\n", "init: name='init_1' type=dtype('float32') shape=(0,) -- array([-2., 3.], dtype=float32)\n", "init: name='init_2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='init_3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "init: name='init_5' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='init_b10' type=dtype('float32') shape=(0,) -- array([-1.], dtype=float32)\n", "init: name='init_b11' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "Mul(x, x) -> out_mul_0\n", " Mul(out_mul_0, init) -> out_mul_0_1\n", "Mul(x, init_1) -> out_mul_0_2\n", " Add(out_mul_0_2, out_mul_0_1) -> out_add_0\n", " Slice(out_add_0, init_2, init_3, init_2) -> out_sli_0\n", " Squeeze(out_sli_0, init_2) -> out_squ_0\n", " Slice(out_add_0, init_5, init_2, init_2) -> out_sli_0_1\n", " Squeeze(out_sli_0_1, init_2) -> out_squ_0_1\n", " Add(out_squ_0_1, out_squ_0) -> out_add_0_1\n", " Add(out_add_0_1, init_b10) -> out_add_0_2\n", " Reshape(out_add_0_2, init_b11) -> y\n", " Sub(y, label) -> loss_diff\n", " Mul(loss_diff, loss_diff) -> loss_diff_2\n", " ReduceSum(loss_diff_2) -> loss\n", "output: name='loss' type=dtype('float32') shape=(1, 1)\n", "output: name='y' type=dtype('float32') shape=(0, 1)\n"]}], "source": ["print(onnx_simple_text_plot(onx_loss))"]}, {"cell_type": "markdown", 
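"id": "c7e2a9f1", "metadata": {}, "source": ["The three nodes appended by `add_loss_output` (`Sub`, `Mul`, `ReduceSum`) compute a sum of squared errors. As a rough check outside ONNX, the same loss and its derivative with respect to the prediction can be sketched with numpy. The helper names `square_loss` and `square_loss_grad` are made up for this illustration and are not part of onnxcustom or mlprodict."]}, {"cell_type": "code", "execution_count": null, "id": "d4b8f603", "metadata": {}, "outputs": [], "source": ["import numpy\n", "\n", "\n", "def square_loss(pred, label):\n", "    # hypothetical helper: sum of squared errors (Sub, then Mul, then ReduceSum)\n", "    diff = pred - label\n", "    return numpy.sum(diff * diff)\n", "\n", "\n", "def square_loss_grad(pred, label):\n", "    # hypothetical helper: derivative of the loss with respect to pred\n", "    return 2 * (pred - label)\n", "\n", "\n", "pred = numpy.array([[1.0], [2.0]], dtype=numpy.float32)\n", "target = numpy.array([[0.0], [4.0]], dtype=numpy.float32)\n", "# (1 - 0)^2 + (2 - 4)^2 = 5\n", "print(square_loss(pred, target))\n", "print(square_loss_grad(pred, target))"]}, {"cell_type": "markdown", 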
"id": "db99d22e", "metadata": {}, "source": ["The graph has 5 inputs: `x`, `label` or the expected target, and the weights and two outputs, the function output and the loss. We don't need the first one so we remove it."]}, {"cell_type": "code", "execution_count": 13, "id": "15ca92ad", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "input: name='label' type=dtype('float32') shape=(0, 1)\n", "init: name='init' type=dtype('float32') shape=(0,) -- array([ 0.5 , -0.33333], dtype=float32)\n", "init: name='init_1' type=dtype('float32') shape=(0,) -- array([-2., 3.], dtype=float32)\n", "init: name='init_2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='init_3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "init: name='init_5' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='init_b10' type=dtype('float32') shape=(0,) -- array([-1.], dtype=float32)\n", "init: name='init_b11' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "Mul(x, x) -> out_mul_0\n", " Mul(out_mul_0, init) -> out_mul_0_1\n", "Mul(x, init_1) -> out_mul_0_2\n", " Add(out_mul_0_2, out_mul_0_1) -> out_add_0\n", " Slice(out_add_0, init_5, init_2, init_2) -> out_sli_0_1\n", " Squeeze(out_sli_0_1, init_2) -> out_squ_0_1\n", " Slice(out_add_0, init_2, init_3, init_2) -> out_sli_0\n", " Squeeze(out_sli_0, init_2) -> out_squ_0\n", " Add(out_squ_0_1, out_squ_0) -> out_add_0_1\n", " Add(out_add_0_1, init_b10) -> out_add_0_2\n", " Reshape(out_add_0_2, init_b11) -> y\n", " Sub(y, label) -> loss_diff\n", " Mul(loss_diff, loss_diff) -> loss_diff_2\n", " ReduceSum(loss_diff_2) -> loss\n", "output: name='loss' type=dtype('float32') shape=(1, 1)\n"]}], "source": ["from mlprodict.onnx_tools.onnx_manipulations import select_model_inputs_outputs\n", "\n", "onx_loss_only = select_model_inputs_outputs(onx_loss, 
outputs=['loss'])\n", "print(onnx_simple_text_plot(onx_loss_only))"]}, {"cell_type": "markdown", "id": "3d3a5477", "metadata": {}, "source": ["## Gradient again: loss + backpropagation"]}, {"cell_type": "code", "execution_count": 14, "id": "0496baf6", "metadata": {"scrolled": false}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 15, "metadata": {}, "output_type": "execute_result"}], "source": ["grad_loss = onnx_rename_names(onnx_derivative(\n", " onx_loss_only, options=DerivativeOptions.FillGrad | DerivativeOptions.KeepOutputs))\n", "%onnxview grad_loss"]}, {"cell_type": "code", "execution_count": 15, "id": "e497697d", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "opset: domain='com.microsoft.nchwc' version=1\n", "opset: domain='ai.onnx.ml' version=2\n", "opset: domain='com.ms.internal.nhwc' version=1\n", "opset: domain='ai.onnx.training' version=1\n", "opset: domain='ai.onnx.preview.training' version=1\n", "opset: domain='com.microsoft' version=1\n", "opset: domain='com.microsoft.experimental' version=1\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "input: name='label' type=dtype('float32') shape=(0, 1)\n", "input: name='init' type=dtype('float32') shape=(1, 2)\n", "input: name='init_1' type=dtype('float32') shape=(1, 2)\n", "input: name='init_b10' type=dtype('float32') shape=(1,)\n", "init: name='i0' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "init: name='i1' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='i2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='i3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "Mul(x, init_1) -> r0\n", " Shape(r0) -> r47\n", "Mul(x, x) -> r1\n", " Mul(r1, init) -> r2\n", " Add(r0, r2) -> r3\n", " Slice(r3, i1, i2, i2) -> r4\n", " Squeeze(r4, i2) -> r5\n", " Shape(r5) -> r33\n", " Slice(r3, i2, i3, i2) -> r6\n", " Squeeze(r6, i2) -> r7\n", " Add(r5, r7) -> r8\n", " Add(r8, init_b10) -> r9\n", " Reshape(r9, i0, allowzero=0) -> r10\n", " Sub(r10, label) -> r11\n", " Mul(r11, r11) -> r12\n", " ReduceSum(r12, keepdims=1, noop_with_empty_axes=0) -> loss\n", " Shape(loss) -> r76\n", " ConstantOfShape(r76) -> r14\n", " Shape(r12) -> r13\n", " 
Expand(r14, r13) -> r15\n", " Mul(r15, r11) -> r16\n", " Sum(r16, r16) -> r17\n", "Shape(label) -> r18\n", "Shape(r10) -> r19\n", " BroadcastGradientArgs(r19, r18) -> r20, r21\n", " ReduceSum(r17, r20, keepdims=1, noop_with_empty_axes=1) -> r22\n", " Reshape(r22, r19, allowzero=0) -> r23\n", "Shape(r9) -> r24\n", " Reshape(r23, r24, allowzero=0) -> r25\n", "Shape(init_b10) -> r26\n", "Shape(r8) -> r27\n", " BroadcastGradientArgs(r27, r26) -> r28, r29\n", " ReduceSum(r25, r28, keepdims=1, noop_with_empty_axes=1) -> r30\n", " Reshape(r30, r27, allowzero=0) -> r31\n", "Shape(r7) -> r32\n", " BroadcastGradientArgs(r33, r32) -> r34, r35\n", " ReduceSum(r31, r34, keepdims=1, noop_with_empty_axes=1) -> r36\n", " Reshape(r36, r33, allowzero=0) -> r37\n", " Unsqueeze(r37, i2) -> r38\n", " Shape(r3) -> r39\n", " SliceGrad(r38, r39, i1, i2, i2) -> r40\n", " ReduceSum(r31, r35, keepdims=1, noop_with_empty_axes=1) -> r41\n", " Reshape(r41, r32, allowzero=0) -> r42\n", " Unsqueeze(r42, i2) -> r43\n", " SliceGrad(r43, r39, i2, i3, i2) -> r44\n", " Sum(r44, r40) -> r45\n", " Shape(r2) -> r46\n", " BroadcastGradientArgs(r47, r46) -> r48, r49\n", " ReduceSum(r45, r48, keepdims=1, noop_with_empty_axes=1) -> r50\n", " Reshape(r50, r47, allowzero=0) -> r51\n", " Mul(r51, init_1) -> r52\n", "Shape(init_1) -> r53\n", "Shape(x) -> r54\n", " BroadcastGradientArgs(r54, r53) -> r55, r56\n", " ReduceSum(r52, r55, keepdims=1, noop_with_empty_axes=1) -> r57\n", " Reshape(r57, r54, allowzero=0) -> r58\n", "ReduceSum(r45, r49, keepdims=1, noop_with_empty_axes=1) -> r59\n", " Reshape(r59, r46, allowzero=0) -> r60\n", " Mul(r60, init) -> r61\n", "Shape(init) -> r62\n", "Shape(r1) -> r63\n", " BroadcastGradientArgs(r63, r62) -> r64, r65\n", " ReduceSum(r61, r64, keepdims=1, noop_with_empty_axes=1) -> r66\n", " Reshape(r66, r63, allowzero=0) -> r67\n", " Mul(r67, x) -> r68\n", " Sum(r68, r68, r58) -> x_grad\n", "ReduceSum(r17, r21, keepdims=1, noop_with_empty_axes=1) -> r69\n", " Reshape(r69, r18, 
allowzero=0) -> r70\n", " Neg(r70) -> label_grad\n", "ReduceSum(r25, r29, keepdims=1, noop_with_empty_axes=1) -> r71\n", " Reshape(r71, r26, allowzero=0) -> init_b10_grad\n", "Mul(r51, x) -> r72\n", " ReduceSum(r72, r56, keepdims=1, noop_with_empty_axes=1) -> r73\n", " Reshape(r73, r53, allowzero=0) -> init_1_grad\n", "Mul(r60, r1) -> r74\n", " ReduceSum(r74, r65, keepdims=1, noop_with_empty_axes=1) -> r75\n", " Reshape(r75, r62, allowzero=0) -> init_grad\n", "output: name='x_grad' type=dtype('float32') shape=(0, 2)\n", "output: name='label_grad' type=dtype('float32') shape=(0, 1)\n", "output: name='init_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_1_grad' type=dtype('float32') shape=(1, 2)\n", "output: name='init_b10_grad' type=dtype('float32') shape=(1,)\n", "output: name='loss' type=dtype('float32') shape=(1, 1)\n"]}], "source": ["print(onnx_simple_text_plot(grad_loss))"]}, {"cell_type": "markdown", "id": "ba52a6bd", "metadata": {}, "source": ["Let's compute the gradient."]}, {"cell_type": "code", "execution_count": 16, "id": "8fdbf27e", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[0., 0.],\n", " [1., 0.],\n", " [0., 1.],\n", " [1., 1.],\n", " [2., 2.]], dtype=float32)"]}, "execution_count": 17, "metadata": {}, "output_type": "execute_result"}], "source": ["x"]}, {"cell_type": "code", "execution_count": 17, "id": "e7104f1c", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[-1. 
],\n", " [-2.5 ],\n", " [ 1.6666701 ],\n", " [ 0.16667008],\n", " [ 1.6666799 ]], dtype=float32)"]}, "execution_count": 18, "metadata": {}, "output_type": "execute_result"}], "source": ["y = fct(x)\n", "y"]}, {"cell_type": "code", "execution_count": 18, "id": "3f7240ef", "metadata": {"scrolled": false}, "outputs": [], "source": ["from mlprodict.onnxrt import OnnxInference\n", "\n", "oinf = OnnxInference(grad_loss, runtime='onnxruntime1')"]}, {"cell_type": "code", "execution_count": 19, "id": "550a9073", "metadata": {"scrolled": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["{'init_1_grad': array([[109.333244, 102.666565]], dtype=float32),\n", " 'init_b10_grad': array([76.6666], dtype=float32),\n", " 'init_grad': array([[193.33316, 186.66649]], dtype=float32),\n", " 'label_grad': array([[ -4. ],\n", " [-12. ],\n", " [ -5.33332 ],\n", " [-13.333321],\n", " [-41.99996 ]], dtype=float32),\n", " 'loss': array([[532.5546]], dtype=float32),\n", " 'x_grad': array([[ 2. , 1.33332 ],\n", " [ 54. , 3.99996 ],\n", " [ 2.66666 , 33.777676],\n", " [ 59.999943, 84.44432 ],\n", " [356.99966 , 517.9994 ]], dtype=float32)}\n"]}], "source": ["import pprint\n", "\n", "init = numpy.array([[2, 3]], dtype=numpy.float32)\n", "init_1 = numpy.array([[0.5, 0.33333]], dtype=numpy.float32)\n", "init_b10 = numpy.array([1], dtype=numpy.float32)\n", "result = oinf.run({'x': x, 'label': y,\n", "                   'init': init, 'init_1': init_1, 'init_b10': init_b10})\n", "pprint.pprint(result)"]}, {"cell_type": "markdown", "id": "211dec4a", "metadata": {}, "source": ["We could use this gradient to implement a stochastic gradient descent in Python. Two comments:\n", "* If we implement it with numpy, it cannot run on GPU.\n", "* If we use OrtValue (the tensor type from onnxruntime), how do we do a simple addition between two OrtValue?\n", "\n", "We need to implement the second option. 
A simple addition between two OrtValue must be done with an ONNX graph."]}, {"cell_type": "markdown", "id": "961435d5", "metadata": {}, "source": ["## TrainingSession"]}, {"cell_type": "code", "execution_count": 20, "id": "18fd1c4d", "metadata": {}, "outputs": [{"data": {"text/plain": ["((100, 2), (100, 1))"]}, "execution_count": 21, "metadata": {}, "output_type": "execute_result"}], "source": ["X = numpy.random.randn(100, 2).astype(numpy.float32) / 10\n", "y = fct(X) + (numpy.random.randn(100, 1) / 1000).astype(numpy.float32)\n", "X.shape, y.shape"]}, {"cell_type": "code", "execution_count": 21, "id": "60e1c260", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "init: name='init' type=dtype('float32') shape=(0,) -- array([ 0.5 , -0.33333], dtype=float32)\n", "init: name='init_1' type=dtype('float32') shape=(0,) -- array([-2., 3.], dtype=float32)\n", "init: name='init_2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='init_3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "init: name='init_5' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='init_b10' type=dtype('float32') shape=(0,) -- array([-1.], dtype=float32)\n", "init: name='init_b11' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "Mul(x, x) -> out_mul_0\n", " Mul(out_mul_0, init) -> out_mul_0_1\n", "Mul(x, init_1) -> out_mul_0_2\n", " Add(out_mul_0_2, out_mul_0_1) -> out_add_0\n", " Slice(out_add_0, init_2, init_3, init_2) -> out_sli_0\n", " Squeeze(out_sli_0, init_2) -> out_squ_0\n", " Slice(out_add_0, init_5, init_2, init_2) -> out_sli_0_1\n", " Squeeze(out_sli_0_1, init_2) -> out_squ_0_1\n", " Add(out_squ_0_1, out_squ_0) -> out_add_0_1\n", " Add(out_add_0_1, init_b10) -> out_add_0_2\n", " Reshape(out_add_0_2, init_b11) -> y\n", "output: name='y' type=dtype('float32') shape=(0, 1)\n"]}], "source": 
["print(onnx_simple_text_plot(onx))"]}, {"cell_type": "code", "execution_count": 22, "id": "f6bfc26e", "metadata": {}, "outputs": [{"data": {"text/plain": ["OrtGradientOptimizer(model_onnx='ir_version...', weights_to_train=['init', 'init_1', 'init_b10'], loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=5, learning_rate=LearningRateSGD(eta0=0.1, alpha=0.0001, power_t=0.25, learning_rate='invscaling'), value=0.03162277660168379, device='cpu', warm_start=False, verbose=0, validation_every=10, saved_gradient=None, sample_weight_name='weight')"]}, "execution_count": 23, "metadata": {}, "output_type": "execute_result"}], "source": ["from onnxcustom.training.optimizers import OrtGradientOptimizer\n", "\n", "train_session = OrtGradientOptimizer(\n", " onx_loss, ['init', 'init_1', 'init_b10'], learning_rate=1e-1,\n", " batch_size=5, max_iter=100)\n", "\n", "train_session.fit(X, y)"]}, {"cell_type": "code", "execution_count": 23, "id": "5880b69f", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'init': array([[-0.34785354, 1.1399053 ]], dtype=float32),\n", " 'init_1': array([[-1.9156165, 2.4292002]], dtype=float32),\n", " 'init_b10': array([-1.0016667], dtype=float32)}"]}, "execution_count": 24, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.trained_coef_"]}, {"cell_type": "code", "execution_count": 24, "id": "f19f97fe", "metadata": {}, "outputs": [{"data": {"text/plain": ["[0.0036812867, 0.0038135047, 0.0037041684, 0.0037206002, 0.0032002896]"]}, "execution_count": 25, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.train_losses_[-5:]"]}, {"cell_type": "code", "execution_count": 25, "id": "82cbd6cf", "metadata": {}, "outputs": [{"data": {"image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAfQElEQVR4nO3de5gddZ3n8fe36py+pTudW4eEdCCJBDWGm3QirhpYReWyC7LeyI4DKJBxZkBWfNxhdB901FlHsw4z7qKYcRhEVy4qulnJDLrIGFnBpYEEEi4hXEI6XNLpJJ1O+nYu3/2jqk9O39In6dNp6uTzep5+6FNVfepXXeHTv/Ot36/K3B0REUm+YLIbICIi5aFAFxGpEAp0EZEKoUAXEakQCnQRkQqRmqwdz5o1yxcsWDBZuxcRSaRHH310l7s3jbRu0gJ9wYIFtLa2TtbuRUQSycy2jbZOJRcRkQqhQBcRqRAKdBGRCjFpNXQRkXLIZDK0tbXR29s72U0pq5qaGpqbm0mn0yX/jAJdRBKtra2NhoYGFixYgJlNdnPKwt3p6Oigra2NhQsXlvxzKrmISKL19vYyc+bMiglzADNj5syZh/2pQ4EuIolXSWE+4EiOacxAN7NbzWynmW0aY7tlZpY1s48cdisOw7OvdfGtXz3Lrv19E7kbEZHEKaWHfhtw3qE2MLMQ+AbwqzK06ZCeb9/Pf//NVjr290/0rkRESlJfXz/ZTQBKCHR3Xw/sHmOza4GfATvL0ahDCYPoY0gml5/oXYmIJMq4a+hmNg+4BPhuCduuMrNWM2ttb28/ov2lwyjQc3k9aUlE3ljcnc9//vMsXbqUU045hbvuuguAV199lRUrVnD66aezdOlSfve735HL5bjiiisK2950003j3n85hi3+HfAX7p4fq4jv7muANQAtLS1HlMhhEP0NyubVQxeRwf7qf2/mqVf2lfU9lxw/lS/9+7eVtO0999zDhg0b2LhxI7t27WLZsmWsWLGCH//4x3zwgx/ki1/8Irlcju7ubjZs2MCOHTvYtCm6PLl3795xt7Ucgd4C3BmH+SzgAjPLuvsvyvDew6Tikks2px66iLyxPPjgg6xcuZIwDDnuuOM4++yzeeSRR1i2bBmf+tSnyGQyfOhDH+L0009n0aJFvPDCC1x77bVceOGFfOADHxj3/scd6O5eGPVuZrcBv5yoMIeDga6Si4gMVWpP+mhbsWIF69ev59577+WKK67g+uuv57LLLmPjxo3cd9993HLLLdx9993ceuut49pPKcMW7wAeAt5sZm1mdqWZfdrMPj2uPR+hVFxDzyjQReQN5j3veQ933XUXuVyO9vZ21q9fz/Lly9m2bRvHHXccV199NVdddRWPPfYYu3btIp/P8+EPf5ivfe1rPPbYY+Pe/5g9dHdfWeqbufsV42pNCVJxDT2nGrqIvMFccsklPPTQQ5x22mmYGd/85jeZM2cOP/jBD1i9ejXpdJr6+npuv/12duzYwSc/+UnycZZ9/etfH/f+E3cvl4PDFtVDF5E3hv379wPR7M7Vq1ezevXqQesvv/xyLr/88mE/V45eebHETf1PhwM9dAW6iEixxAW6JhaJiIwscYGuUS4iMpR75eXBkRxT8gI9HuWSVaCLCNGDIDo6Oioq1Afuh15TU3NYP5e4i6IDo1w0sUhEAJqbm2lra+NIbyfyRjXwxKLDkbxAL9zLRTV0EYF0On1YT/WpZMkruWjYoojIiJIX6Bq2KCIyouQF+kAPXSUXEZFBEhvoOZVcREQGSVygFyYWqeQiIjJI4gLdzAgD0ygXEZEhEhfoEJVdNLFIRGSw5Aa6augiIoMkM9DDQMMWRUSGSGagB6a7LYqIDJHMQA9NPXQRkSGSGehBoKn/IiJDlPKQ6FvNbKeZbRpl/R+Z2RNm9qSZ/d7MTit/MwfTsEURkeFK6aHfBpx3iPUvAme7+ynAV4E1ZWjXIaVCDVsUERlqzNvnuvt6M1twiPW/L3r
5MHB4N/A9Ahq2KCIyXLlr6FcC/zzaSjNbZWatZtY6npvRp4JAPXQRkSHKFuhm9m+JAv0vRtvG3de4e4u7tzQ1NR3xvqKSi2roIiLFyvLEIjM7Ffg+cL67d5TjPQ8lFWjYoojIUOPuoZvZCcA9wB+7+5bxN2ls0bBF9dBFRIqN2UM3szuAc4BZZtYGfAlIA7j7LcCNwEzgO2YGkHX3lolqMEQlFwW6iMhgpYxyWTnG+quAq8rWohKEgdHdr5KLiEixhM4UVQ1dRGSoZAZ6qGGLIiJDJTPQAyOrGrqIyCDJDHTdD11EZJhkBnpgZDSxSERkkMQGek73chERGSSZgR4aGZVcREQGSWSghxq2KCIyTCIDPRUEGuUiIjJEQgNdD7gQERkqmYGuiUUiIsMkM9A1sUhEZJhkBnpo5B3y6qWLiBQkM9ADA1DZRUSkSCIDPQyiZmvooojIQYkM9HQY9dA1/V9E5KBEBnoYl1w0/V9E5KBEBnoqjJqtGrqIyEHJDPTCRVGVXEREBowZ6GZ2q5ntNLNNo6w3M/u2mW01syfM7O3lb+ZghUBXyUVEpKCUHvptwHmHWH8+sDj+WgV8d/zNOrRUqGGLIiJDjRno7r4e2H2ITS4GbvfIw8A0M5tbrgaOJFUYtqiSi4jIgHLU0OcB24tet8XLhjGzVWbWamat7e3tR7zDgZJLRiUXEZGCo3pR1N3XuHuLu7c0NTUd8fsUhi2q5CIiUlCOQN8BzC963RwvmzBpDVsUERmmHIG+FrgsHu1yFtDp7q+W4X1HFRZGuaiGLiIyIDXWBmZ2B3AOMMvM2oAvAWkAd78FWAdcAGwFuoFPTlRjB2iUi4jIcGMGuruvHGO9A39ethaVYGCUi8ahi4gclMyZoqFmioqIDJXMQNdMURGRYRIZ6KEecCEiMkwiA31g2KLGoYuIHJTIQA91t0URkWESGehpjXIRERkmkYEeapSLiMgwiQz0tC6KiogMk8hADzVsUURkmEQGup4pKiIyXDIDXTfnEhEZJpGBrolFIiLDJTLQNbFIRGS4RAZ63EFXyUVEpEgiA93MSIemkouISJFEBjpEdXQFuojIQYkN9HQQaBy6iEiRxAZ6GJqm/ouIFElsoKdUchERGaSkQDez88zsWTPbamY3jLD+BDN7wMweN7MnzOyC8jd1sFQQkFPJRUSkYMxAN7MQuBk4H1gCrDSzJUM2+y/A3e5+BnAp8J1yN3SoMDAyKrmIiBSU0kNfDmx19xfcvR+4E7h4yDYOTI2/bwReKV8TR5YOTROLRESKlBLo84DtRa/b4mXFvgx8wszagHXAtSO9kZmtMrNWM2ttb28/guYeFAamUS4iIkXKdVF0JXCbuzcDFwA/NLNh7+3ua9y9xd1bmpqaxrXDdBholIuISJFSAn0HML/odXO8rNiVwN0A7v4QUAPMKkcDR6MeuojIYKUE+iPAYjNbaGZVRBc91w7Z5mXgfQBm9laiQB9fTWUMGrYoIjLYmIHu7lngGuA+4Gmi0SybzewrZnZRvNnngKvNbCNwB3CFu09o2qZUchERGSRVykbuvo7oYmfxshuLvn8KeFd5m3ZoKrmIiAyW2JmiGrYoIjJYYgM9DAIyCnQRkYLEBno6MHKqoYuIFCQ20FVDFxEZLLGBHk0sUqCLiAxIbKBHPXSVXEREBiQ20DWxSERksOQGuoYtiogMkthAD4OAjC6KiogUJDbQo4lFqqGLiAxIbKBr2KKIyGCJDXQNWxQRGSyxgR4GprstiogUSWyga9iiiMhgCQ70AHc0dFFEJJbcQA8NQGUXEZFYcgM9iAJdPXQRkUhiAz2MA12Ti0REIokN9HQYNV09dBGRSEmBbmbnmdmzZrbVzG4YZZuPmdlTZrbZzH5c3mYON9BD1x0XRUQiYz4k2sxC4Gbg/UAb8IiZrY0fDD2wzWLgL4F3ufseM5s9UQ0ekC5cFFUPXUQESuuhLwe2uvsL7t4P3AlcPGSbq4Gb3X0PgLvvLG8
zhwuDqOma/i8iEikl0OcB24tet8XLip0MnGxm/9fMHjaz80Z6IzNbZWatZtba3t5+ZC2ODYxy0bBFEZFIuS6KpoDFwDnASuAfzGza0I3cfY27t7h7S1NT0/h2GGrYoohIsVICfQcwv+h1c7ysWBuw1t0z7v4isIUo4CdMSsMWRUQGKSXQHwEWm9lCM6sCLgXWDtnmF0S9c8xsFlEJ5oXyNXO4VKBhiyIixcYMdHfPAtcA9wFPA3e7+2Yz+4qZXRRvdh/QYWZPAQ8An3f3jolqNEAYl1wyqqGLiAAlDFsEcPd1wLohy24s+t6B6+OvoyKtHrqIyCCJnSl6cOq/eugiIpDgQNcoFxGRwZIb6IFmioqIFEtwoGumqIhIseQGeqHkohq6iAgkOdA1sUhEZJDkBrruhy4iMkhyA13DFkVEBklsoId6pqiIyCCJDfRUYeq/Al1EBJIc6ANT/1VyEREBkhzoegSdiMggyQ10zRQVERkkwYGuYYsiIsUSHOgatigiUiyxgR4ERmDqoYuIDEhsoENUdtHUfxGRSKIDPQxMN+cSEYklOtBToWmUi4hIrKRAN7PzzOxZM9tqZjccYrsPm5mbWUv5mji6VGC6H7qISGzMQDezELgZOB9YAqw0syUjbNcAXAf8odyNHE0qDNRDFxGJldJDXw5sdfcX3L0fuBO4eITtvgp8A+gtY/sOKeqhq4YuIgKlBfo8YHvR67Z4WYGZvR2Y7+73HuqNzGyVmbWaWWt7e/thN3aoVGgatigiEhv3RVEzC4C/BT431rbuvsbdW9y9pampaby7joYtKtBFRIDSAn0HML/odXO8bEADsBT4VzN7CTgLWHs0Loxq2KKIyEGlBPojwGIzW2hmVcClwNqBle7e6e6z3H2Buy8AHgYucvfWCWlxkVRgmlgkIhIbM9DdPQtcA9wHPA3c7e6bzewrZnbRRDfwUFRDFxE5KFXKRu6+Dlg3ZNmNo2x7zvibVZpUoGGLIiIDkj1TVMMWRUQKkh3omvovIlKQ7EAPAvXQRURiiQ70aNiieugiIpDwQE+HGrYoIjIg0YGuHrqIyEGJDvTobouqoYuIQNIDPdAoFxGRAQkP9EAPuBARiSU80E0lFxGRWLIDXfdyEREpSHag626LIiIFiQ70MAjUQxcRiSU60KOJRaqhi4hAwgNdE4tERA5KdKBHE4scd4W6iEiyAz0wAPXSRURIeqCHUaBrtqiISNIDPVCgi4gMKCnQzew8M3vWzLaa2Q0jrL/ezJ4ysyfM7H4zO7H8TR0uDKLm5zQWXURk7EA3sxC4GTgfWAKsNLMlQzZ7HGhx91OBnwLfLHdDR5KOSy4ZTf8XESmph74c2OruL7h7P3AncHHxBu7+gLt3xy8fBprL28yRhbooKiJSUEqgzwO2F71ui5eN5krgn0daYWarzKzVzFrb29tLb+Uo0nHJRTV0EZEyXxQ1s08ALcDqkda7+xp3b3H3lqampnHvb6CHrgdFi4hAqoRtdgDzi143x8sGMbNzgS8CZ7t7X3mad2gatigiclApPfRHgMVmttDMqoBLgbXFG5jZGcD3gIvcfWf5mzmy1EDJRaNcRETGDnR3zwLXAPcBTwN3u/tmM/uKmV0Ub7YaqAd+YmYbzGztKG9XVgd76Cq5iIiUUnLB3dcB64Ysu7Ho+3PL3K6SFCYWqYcuIpLsmaKhZoqKiBQkOtCnVEcfMDa/0jnJLRERmXyJDvS3nzCdd500k6/d+zSbdijUReTYluhADwPj25eewawpVXz6R4+y50D/ZDdJRGTSJDrQAWbWV/OdT5zJzn19XHfXBp56ZR9te7rZ35ed7KaJiBxVJY1yeaM7ff40vnzR2/jCz59k/ZbolgJmcP25J3Pt+xZPcutERI6Oigh0gP/4jhM4ZV4jO/Z209mT4bdb2vnWr7cAlBTqnT0Z0qFRV1UxvxIROcZUVHqd0tzIKc2NAHzkzPnUpDfyrV9vIQiMT7zjRF7qOEDbnh5OP2E
a86bVFn6u9aXdXHV7KyfOqONnf/pvSIWJr0SJyDGoogK9WBgYqz9yGu6w+r5nWX3fs4V1NemAz7xvMVe9exH3P/061921gak1aTa2dfIPv3uRPz3nTYVtH3xuFy/v7uZjLc0KehF5QzP3yZmU09LS4q2trRO+n1ze+dHD2+jP5jlxZh2zGqr53m+f577Nr9M8vZYde3s4Y/40vn/5Mr5wz5P85tmdrPvMezhpdj3rnnyVa+94nFzeeevcqfzXS5ZyxgnTJ7zNIiKjMbNH3b1lxHWVHuijuf/p1/navU+z5PipfOujp1GTDtnZ1csHblrPollTuOJdC/nsXRs4ff40LnvniXx93TO83tXLB5fMYWHTFOZMrWHx7HrOWjSTIJ6xKiIy0RToh+Hnj7fx2bs2ArBswXT+6ZPLqa9Osb8vy02/3sK/bHqN1/f1Fm43cPJx9axa8SYuOu14qlIqyYjIxFKgHwZ353N3b6SzJ8O3V55RuL1AsXze6TjQz4Nb2/neb1/gmde6mFVfzXvf0sTZJ8/m3Ytn0VibPuQ+frulnUzOmTGliplTqqirDqlJh9SkQv1hEJFRKdAnkLvzr1va+Unrdn733C66erOYwZuPa+DME6ezfOEMLjhlLun4gmo+79y4dhM/evjlUd/zo2c289UPLaUmHR6twxCRhFCgHyXZXJ4N2/fy4NZdPLptDxte3ktXX5Ylc6fyzY+cypK5U/nCz5/kzke28ycrFnHhqXPpONDP7v39dGdy9GVyvNRxgB89/DKnNTdyyx+fydzG2rF3PIbtu7vpy+Y5aXZ9GY5SRCaTAn2S5PLOrza/xo1rN7P7QD+nNjfy+Mt7+cx7T+Kz7z8Zs5Evpv5q82tcf/dGatIBV757ESfOrOOEGXUA7Ozqpb2rj8baNMsXzmTGlKphP/9qZw8Pv9DBQ8938PvnO2jb0wPA3/yHU7h0+QkTd8AiMuEU6JOsszvDX697irtb2/jsuSdz3bljz1zdurOLa378OM+81nXI7d4yp4GFs6aQyeXpy+bZvrublzq6AWisTXPWohm8c9FMHni2nd9uaecbHz6Fjy+LQr1jfx8PvdBBd3+O/myevDuzG6ppnl7HnMYacnmnuz9HbybH8Y21NNaNfl1ARI4OBfobxL7eDFNrDi8U9/Vm2L67m+27uwFj9tRqZjdU8/q+vkIv/LV9vVSFAVWpgNkN1bxj0UzOWjSDt86ZWhhS2ZvJ8Sc/fJT1z7XzZ+e8iS2v7+eBZ3Ye1sNBmhqqOampnmw+T8f+fnZ39/PWOVP52LJmzl86VzV/kaNAgS5AFOqrfvgo67e009RQzSVnzOPCU+YyY0oV1akAM+P1fb207enhtc4e0qmAuqqQqjCkbU83z+3cz/Pt+6kKA2Y1VNNYmy7MpG2oSXFqcyPTaquYVpempz/H9j3dhXLP/Bl1nDijjpp0yCt7e9ixt4d9PRlSYUAqNAIzsrk8/dk8QWDMn17Hgll1zJ9Rx8wpVUyvq6KhJk1vNseBviy9mTyzG6qZN72W4xtr6c/m6ezJ0NWXYcHMKSOOThpLLu/k3QsXsEXeiMYd6GZ2HvD3QAh8393/Zsj6auB24EygA/i4u790qPdUoE+OvmyOza/s49R5jWW5lUE+7/zhxd387LE2Xtx1gD3d/eztzlCTCmieUUfz9Oii7ssd3Wzb3U1/Ns/x02qZN62Gxtoqcvk8mbzjcZCmw4BsLs/Lu7vZ1tFNxxHc4z4VGKfPn8Y73zSTXN7Z1tHNtt0H6O7LMfCvfWptmvnTa2meXkd3f5ZNOzp56tV95PPw5jkNLJ3XyFvmNDBvWi3HT6ulqaGauqpoaGk2n2fHnh627+lhV1cfU6pDGmrSNNSkiv6boioMCtdJ3J2uviyd3Rnco7uBAuzp7ue1zl52dvXR1FDNqc2NzJlaM+r1FZFxBbqZhcAW4P1AG/AIsNLdnyra5s+AU93902Z2KXCJu3/8UO+rQJdSdPdn2X2gnz0
HMnT1ZqipCqmvjsJyZ1cfbXu6ebWzl+pUQGNtmrqqFJte6eT3z3fwZNteAjPmz4guKk8tmhuw50A/bXu62bG3h6ow4G3HN7J0XiPp0Nj0SidPtnWyr3d899QPDGrSIekwYH9fllyJ5a1Z9dXUV4d09WbZ15sh71CdCqhJh9SmQxpqUkypTjFjShULZtZx4swpzJxSxYH+6NPLgf4s3X05DvRn6c/mqUmHhT9GubwXvvrjT0SZXJ7adEh9TYr66hSpwAjDgNCMbD5PXyZPXzZHOgyoq04xpSqkOhWSCo2quFOQyeXJ5p3eTI79fVm6eqPjrUkH1KZD6qpSTK1N01ibpiYd0JvJ05vJ0ZOJrtH09Ofoz+VJh0GhfGgGhmEW3ZspHRqpIFoeWLS8OhUdW206ak8YRJ/2+jJ59se/i1RgTK1NM7Um2veAdBgUPpmWyj26rnSgP0tDdfR+R/uP76ECvZTPpcuBre7+QvxmdwIXA08VbXMx8OX4+58C/8PMzCerniMVo64qRV1ViuYRbqGzYNYUli+cMWz5hafOBaI/BlVhcMhPIrm8YzDs9g3uTntXH6909vLK3h527e+LgycPQPP0WubPqGN2QzXd/Tm6ejN09Wbp6suwrydLV2+mEFr9uTz11Smm11XRWJsmDIy8Ow5Mr6viuKnVzG6o4dXOHp5o6+TJHZ30ZfNMjXv8YcDBAOzPFYJqW8cB1m9ppy+bH3ZcgUXP3K0KA3ri4Cz+vzEMrBCcqcDoyeTo7s+NfUIqjBnUpkPCwMjlnWwu+iVVpwKq0yFVoeFA3qN1+3ozZHIHf5FVYUBDTQozcKfwCdAY+BR28N9VKjDSKSMdBqxcdgJXr1hU9uMpJdDnAduLXrcB7xhtG3fPmlknMBPYVbyRma0CVgGccIKGz8nEKuXe9uEo9+ExM2ZPrWH21BpOnz+tzC0b2ZzGmsO++Vs+77ze1cueAxnqq1NMqQ6ZUp0a1vN0j3rkqSAgMEbsVebyzoH+LNncwV58OrQ42AKy+ajX292XG9S7NzNSgZEKjepU9AlioKffm83TE39y2NebobMnQ18m+tRQk44/dcQ97HQYkCl634FwHAjTbN7Jxsvdo/b2ZaM/ct39ucI1kJw71amQ+vh3MRDEnT0Z+ov++PXn8vT2R3/sMrnoWMMg+uPfl83Rl42v6cSfFMLQaIw/ZdRVhezvy9LZk2F/bxaHwnaODwp34vbm8x4dXy5PU0P1YZ3nUh3V2+e6+xpgDUQll6O5b5FKFATG3MbaMSegmUVheyhhYIcchVVFEP2RbCi9ffVhQH11asICTAYr5arYDmB+0evmeNmI25hZCmgkujgqIiJHSSmB/giw2MwWmlkVcCmwdsg2a4HL4+8/AvxG9XMRkaNrzJJLXBO/BriPaNjire6+2cy+ArS6+1rgH4EfmtlWYDdR6IuIyFFUUg3d3dcB64Ysu7Ho+17go+VtmoiIHA5NiRMRqRAKdBGRCqFAFxGpEAp0EZEKMWl3WzSzdmDbEf74LIbMQj1GHIvHfSweMxybx30sHjMc/nGf6O5NI62YtEAfDzNrHe3mNJXsWDzuY/GY4dg87mPxmKG8x62Si4hIhVCgi4hUiKQG+prJbsAkORaP+1g8Zjg2j/tYPGYo43EnsoYuIiLDJbWHLiIiQyjQRUQqROIC3czOM7NnzWyrmd0w2e2ZCGY238weMLOnzGyzmV0XL59hZr82s+fi/x7e420SwsxCM3vczH4Zv15oZn+Iz/ld8W2cK4aZTTOzn5rZM2b2tJm981g412b22fjf9yYzu8PMairxXJvZrWa208w2FS0b8fxa5Nvx8T9hZm8/nH0lKtDjB1bfDJwPLAFWmtmSyW3VhMgCn3P3JcBZwJ/Hx3kDcL+7Lwbuj19XouuAp4tefwO4yd1PAvYAV05KqybO3wP/4u5vAU4jOvaKPtdmNg/4DNDi7kuJbs19KZV5rm8DzhuybLTzez6wOP5aBXz3cHaUqEC
n6IHV7t4PDDywuqK4+6vu/lj8fRfR/+DziI71B/FmPwA+NCkNnEBm1gxcCHw/fm3Ae4kePg4Vdtxm1gisIHqmAO7e7+57OQbONdHtu2vjp5zVAa9Sgefa3dcTPSei2Gjn92Lgdo88DEwzs7ml7itpgT7SA6vnTVJbjgozWwCcAfwBOM7dX41XvQYcN1ntmkB/B/xnYOBpvjOBve6ejV9X2jlfCLQD/xSXmb5vZlOo8HPt7juA/wa8TBTkncCjVPa5Ljba+R1XxiUt0I8pZlYP/Az4T+6+r3hd/Ii/ihpzamb/Dtjp7o9OdluOohTwduC77n4GcIAh5ZUKPdfTiXqjC4HjgSkML0scE8p5fpMW6KU8sLoimFmaKMz/p7vfEy9+feDjV/zfnZPVvgnyLuAiM3uJqJz2XqL68rT4YzlU3jlvA9rc/Q/x658SBXyln+tzgRfdvd3dM8A9ROe/ks91sdHO77gyLmmBXsoDqxMvrhv/I/C0u/9t0arih3FfDvyvo922ieTuf+nuze6+gOjc/sbd/wh4gOjh41Bhx+3urwHbzezN8aL3AU9R4eeaqNRylpnVxf/eB467Ys/1EKOd37XAZfFol7OAzqLSzNjcPVFfwAXAFuB54IuT3Z4JOsZ3E30EewLYEH9dQFRPvh94Dvg/wIzJbusE/g7OAX4Zf78I+H/AVuAnQPVkt6/Mx3o60Bqf718A04+Fcw38FfAMsAn4IVBdiecauIPoOkGG6BPZlaOdX8CIRvI9DzxJNAqo5H1p6r+ISIVIWslFRERGoUAXEakQCnQRkQqhQBcRqRAKdBGRCqFAFxGpEAp0EZEK8f8BO6d5zUN015gAAAAASUVORK5CYII=\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["import pandas\n", "\n", "pandas.DataFrame({'loss': train_session.train_losses_}).plot();"]}, {"cell_type": "markdown", "id": "4d3ab358", "metadata": {}, "source": ["## Fordward backward: TrainingAgent\n", "\n", "This second implementation uses [TrainingAgent](http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/api/onnxruntime_python/training_partial.html#trainingagent)."]}, {"cell_type": "code", "execution_count": 26, "id": "809cb2f3", "metadata": {}, "outputs": [], "source": ["from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer\n", "\n", "train_session = OrtGradientForwardBackwardOptimizer(\n", " onx, ['init', 'init_1', 'init_b10'], learning_rate=1e-1, \n", " batch_size=2, max_iter=100)"]}, {"cell_type": "code", "execution_count": 27, "id": "e47bfdf6", "metadata": {}, "outputs": [{"data": {"text/plain": ["OrtGradientForwardBackwardOptimizer(model_onnx='ir_version...', weights_to_train=['init', 'init_1', 'init_b10'], loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=2, learning_rate=LearningRateSGD(eta0=0.1, alpha=0.0001, power_t=0.25, learning_rate='invscaling'), value=0.03162277660168379, device='cpu', warm_start=False, verbose=0, validation_every=10, learning_loss=SquareLearningLoss(), enable_logging=False, weight_name=None, learning_penalty=NoLearningPenalty(), exc=True)"]}, "execution_count": 28, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.fit(X, y)"]}, {"cell_type": "code", "execution_count": 28, "id": "8f34bf7a", "metadata": {}, "outputs": [{"data": {"text/plain": ["[0.00040441833, 0.00037421435, 0.00049950054, 0.00042527347, 0.00031072882]"]}, "execution_count": 29, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.train_losses_[-5:]"]}, {"cell_type": "code", "execution_count": 29, "id": "ba57e16f", "metadata": {}, "outputs": 
[{"data": {"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgF0lEQVR4nO3dfZBddZ3n8ff3nHO7O92d57RB8kASCY4BBLQJOq7BUZQ4zhJctQw7MwYLzbJL1F123cFiBpVxSge2dJwqVklpXLQGAZGaiZoxyyCKriJpIASSGAjhIR2C6XSe00/34bt/nHO7bz8lt9PddPzl86p05Z6ne3+nT/K5v/M9T+buiIhIuKKJboCIiIwvBb2ISOAU9CIigVPQi4gETkEvIhK4ZKIbMNCsWbN8wYIFE90MEZE/KI8//vh+d28aatppF/QLFiygpaVlopshIvIHxcxeGm6aSjciIoFT0IuIBE5BLyISuKpq9Ga2HPg6EAPfcvevDJh+PXADUASOAavdfZuZLQC2AzuyWR919+vHqO0iIsPK5/O0trbS1dU10U0ZU3V1dcydO5dcLlf1MicNejOLgTuA9wKtwCYzW+/u2ypmu9vdv5nNfxXwVWB5Nu15d7+46haJiIyB1tZWJk+ezIIFCzCziW7OmHB32tvbaW1tZeHChVUvV03pZimw0913uXsPcA+wYsCHH6kYbAB0pzQRmVBdXV3MnDkzmJAHMDNmzpw54r2UaoJ+DrC7Yrg1GzewATeY2fPAbcCnKyYtNLMnzewXZvbOoT7AzFabWYuZtbS1tY2g+SIiwwsp5MtOZZ3G7GCsu9/h7m8A/gr462z0XmC+u18C3AjcbWZThlh2rbs3u3tzU9OQ5/uf1PHuAl/9vzt48uWDp7gGIiJhqibo9wDzKobnZuOGcw9wNYC7d7t7e/b6ceB54LxTaulJdBdK/OPPdrKl9fB4vL2IyIg1NjZOdBOA6oJ+E7DYzBaaWQ2wElhfOYOZLa4Y/ADwXDa+KTuYi5ktAhYDu8ai4QPFUbo7Uyjp8ICISKWTBr27F4A1wEbSUyXvc/etZnZrdoYNwBoz22pmm0lLNKuy8cuALdn4+4Hr3f3AGK8DAEkW9MVSaTzeXkTklLk7n/3sZ7ngggu48MILuffeewHYu3cvy5Yt4+KLL+aCCy7gl7/8JcVikWuvvbZ33q997Wuj/vyqzqN39w3AhgHjbql4/Zlhlvsh8MPRNLBa6tGLyHC++KOtbHvlyMlnHIElZ0/h8//+/KrmfeCBB9i8eTNPPfUU+/fv59JLL2XZsmXcfffdXHnlldx8880Ui0U6OjrYvHkze/bs4ZlnngHg0KFDo25rMFfGlnv0haKCXkROL7/61a+45ppriOOY2bNnc/nll7Np0yYuvfRSvvOd7/CFL3yBp59+msmTJ7No0SJ27drFpz71KX76058yZcqg81dG7LS7e+WpUo9eRIZTbc/7tbZs2TIeeeQRfvKTn3Dttddy44038rGPfYynnnqKjRs38s1vfpP77ruPdevWjepzgunRmxlJZKrRi8hp553vfCf33nsvxWKRtrY2HnnkEZYuXcpLL73E7Nmz+eQnP8knPvEJnnjiCfbv30+pVOJDH/oQX/rSl3jiiSdG/fnB9Ogh7dWrRy8ip5sPfvCD/OY3v+Giiy7CzLjttts466yzuOuuu7j99tvJ5XI0Njby3e9+lz179vDxj3+cUtZp/fKXvzzqzw8q6JPIKKpGLyKniWPHjgFpxeH222/n9ttv7zd91apVrFq1atByY9GLrxRM6QbUoxcRGUpQQZ/EEQXV6EVE+gkq6OPIKKpHLyIZ9/Dy4FTWKaigz0Wm8+hFBEgf0NHe3h5U2JfvR19XVzei5YI6GBvH6tGLSGru3Lm0trYS2q3Py0+YGomggj6JIh2MFREAcrnciJ7CFLKgSjeq0YuIDBZU0CeRkS/qrBsRkUpBBb169CI
igwUV9Ol59Ap6EZFKYQW9evQiIoMEFfTpLRBUoxcRqRRU0KtHLyIyWFBBH0dGXlfGioj0U1XQm9lyM9thZjvN7KYhpl9vZk+b2WYz+5WZLamY9rlsuR1mduVYNn4g9ehFRAY7adCbWQzcAbwfWAJcUxnkmbvd/UJ3vxi4DfhqtuwSYCVwPrAc+N/Z+40LnXUjIjJYNT36pcBOd9/l7j3APcCKyhncvfLx6g1AOW1XAPe4e7e7vwDszN5vXOhRgiIig1Vzr5s5wO6K4VbgsoEzmdkNwI1ADfDuimUfHbDsnCGWXQ2sBpg/f3417R6SHjwiIjLYmB2Mdfc73P0NwF8Bfz3CZde6e7O7Nzc1NZ1yG1SjFxEZrJqg3wPMqxiem40bzj3A1ae47KjEUaT70YuIDFBN0G8CFpvZQjOrIT24ur5yBjNbXDH4AeC57PV6YKWZ1ZrZQmAx8Njomz20RBdMiYgMctIavbsXzGwNsBGIgXXuvtXMbgVa3H09sMbMrgDywEFgVbbsVjO7D9gGFIAb3L04TutCogePiIgMUtWDR9x9A7BhwLhbKl5/5gTL/h3wd6fawJFIdDBWRGSQwK6MjSiqRi8i0k9QQZ/E6tGLiAwUVNDr7pUiIoMFFfSq0YuIDBZY0Ee4Q0lhLyLSK6ygjw1AvXoRkQpBBX0cpUGvc+lFRPoEFfRJVO7R64CsiEhZUEFf7tHrfjciIn2CCvq+Hr2CXkSkLKygj9PVUY1eRKRPUEEfq0YvIjJIUEGf6KwbEZFBggr6WDV6EZFBggr6JEpXR2fdiIj0CSroVaMXERksqKDPxarRi4gMFFTQq0YvIjJYUEFfrtGrRy8i0ieooNctEEREBqsq6M1suZntMLOdZnbTENNvNLNtZrbFzB4ys3MqphXNbHP2s34sGz9Q322KdTBWRKQsOdkMZhYDdwDvBVqBTWa23t23Vcz2JNDs7h1m9p+B24CPZtM63f3isW320FSjFxEZrJoe/VJgp7vvcvce4B5gReUM7v6wu3dkg48Cc8e2mdXJlWv0Kt2IiPSqJujnALsrhluzccO5DvjXiuE6M2sxs0fN7OqhFjCz1dk8LW1tbVU0aWjq0YuIDHbS0s1ImNlfAM3A5RWjz3H3PWa2CPiZmT3t7s9XLufua4G1AM3Nzaec0onOoxcRGaSaHv0eYF7F8NxsXD9mdgVwM3CVu3eXx7v7nuzvXcDPgUtG0d4T0pWxIiKDVRP0m4DFZrbQzGqAlUC/s2fM7BLgTtKQ31cxfrqZ1WavZwHvACoP4o6pRKdXiogMctLSjbsXzGwNsBGIgXXuvtXMbgVa3H09cDvQCPzAzABedvergDcBd5pZifRL5SsDztYZU3o4uIjIYFXV6N19A7BhwLhbKl5fMcxyvwYuHE0DR6L37pUKehGRXkFdGdt3MFY1ehGRsrCCXqdXiogMElTQq0YvIjJYUEFfrtHnddaNiEivoIK+r0evGr2ISFlQQa8avYjIYEEFfRQZkalGLyJSKaigh7ROrx69iEif4II+jkw9ehGRCsEFfRIZ+aIOxoqIlAUX9HGsHr2ISKXggj6JTDV6EZEKAQZ9pEcJiohUCC7oY/XoRUT6CS7ok9h0ZayISIXggj6OjLx69CIivYIL+iQy1ehFRCoEF/SxrowVEeknuKDPqUYvItJPVUFvZsvNbIeZ7TSzm4aYfqOZbTOzLWb2kJmdUzFtlZk9l/2sGsvGD0Vn3YiI9HfSoDezGLgDeD+wBLjGzJYMmO1JoNnd3wzcD9yWLTsD+DxwGbAU+LyZTR+75g+W6F43IiL9VNOjXwrsdPdd7t4D3AOsqJzB3R92945s8FFgbvb6SuBBdz/g7geBB4HlY9P0ocWRUdDBWBGRXtUE/Rxgd8VwazZuONcB/3qKy45aepti1ehFRMqSsXwzM/sLoBm4fITLrQZWA8yfP39UbdBtikVE+qumR78HmFcxPDcb14+ZXQHcDFzl7t0jWdbd17p7s7s
3NzU1Vdv2IeViHYwVEalUTdBvAhab2UIzqwFWAusrZzCzS4A7SUN+X8WkjcD7zGx6dhD2fdm4caMevYhIfyct3bh7wczWkAZ0DKxz961mdivQ4u7rgduBRuAHZgbwsrtf5e4HzOxvSb8sAG519wPjsiYZPUpQRKS/qmr07r4B2DBg3C0Vr684wbLrgHWn2sCRSs+60cFYEZGy4K6M1YNHRET6Cy7oVaMXEekvuKBPYtXoRUQqhRf06tGLiPQTXNDHkZHXwVgRkV7BBb169CIi/QUX9LGujBUR6Se4oFePXkSkvwCDPqJYctwV9iIiEGTQG4B69SIimeCCPo7ToFedXkQkFVzQl3v0CnoRkVRwQR9H6SoV9ThBEREgwKDv69HroikREQgx6GMdjBURqRRe0KtGLyLST3BBX67RF1SjFxEBAgx61ehFRPoLLuhjXTAlItJPcEGvGr2ISH9VBb2ZLTezHWa208xuGmL6MjN7wswKZvbhAdOKZrY5+1k/Vg0fThJn59Er6EVEAEhONoOZxcAdwHuBVmCTma13920Vs70MXAv8jyHeotPdLx59U6ujHr2ISH8nDXpgKbDT3XcBmNk9wAqgN+jd/cVs2oQfAS3X6At6ypSICFBd6WYOsLtiuDUbV606M2sxs0fN7OqhZjCz1dk8LW1tbSN468HUoxcR6e+1OBh7jrs3A/8R+Acze8PAGdx9rbs3u3tzU1PTqD5MZ92IiPRXTdDvAeZVDM/NxlXF3fdkf+8Cfg5cMoL2jVii2xSLiPRTTdBvAhab2UIzqwFWAlWdPWNm082sNns9C3gHFbX98ZCU716pC6ZERIAqgt7dC8AaYCOwHbjP3bea2a1mdhWAmV1qZq3AR4A7zWxrtvibgBYzewp4GPjKgLN1xlzfwVj16EVEoLqzbnD3DcCGAeNuqXi9ibSkM3C5XwMXjrKNI6LSjYhIf7oyVkQkcMEFfawavYhIP8EFfaIavYhIP+EFvZ4wJSLST3BBH6tGLyLST3BBn/Q+YUo1ehERCDDo1aMXEekvuKBPdK8bEZF+ggt69ehFRPoLLuhzesKUiEg/wQV91qFXj15EJBNc0JsZSWQ660ZEJBNc0ENap1fpRkQkFWTQJ5GpdCMikgky6NWjFxHpE2TQ5+KIgu5eKSICBBr06tGLiPQJMuiTyMjrNsUiIkCgQR/H6tGLiJQFGfRJFOmsGxGRTFVBb2bLzWyHme00s5uGmL7MzJ4ws4KZfXjAtFVm9lz2s2qsGn4iaY1eB2NFRKCKoDezGLgDeD+wBLjGzJYMmO1l4Frg7gHLzgA+D1wGLAU+b2bTR9/sE0uvjFWPXkQEquvRLwV2uvsud+8B7gFWVM7g7i+6+xZgYDf6SuBBdz/g7geBB4HlY9DuE0pUoxcR6VVN0M8BdlcMt2bjqlHVsma22sxazKylra2tyrceXhxF5BX0IiLAaXIw1t3Xunuzuzc3NTWN+v0S1ehFRHpVE/R7gHkVw3OzcdUYzbKnLFaNXkSkVzVBvwlYbGYLzawGWAmsr/L9NwLvM7Pp2UHY92XjxlWiK2NFRHqdNOjdvQCsIQ3o7cB97r7VzG41s6sAzOxSM2sFPgLcaWZbs2UPAH9L+mWxCbg1Gzeukljn0YuIlCXVzOTuG4ANA8bdUvF6E2lZZqhl1wHrRtHGEVOPXkSkz2lxMHasxZGR1xOmRESAQINePXoRkT5BBr1uUywi0ifIoNejBEVE+oQZ9HGkHr2ISCbMoI9MjxIUEckEGfS6MlZEpE+QQa8avYhInyCDPo5UoxcRKQsy6JNYNXoRkbIwg17n0YuI9Ao26FWjFxFJBRn0cRThjnr1IiIEGvRJbACq04uIEGjQx1Ea9OrRi4gEGvRJVO7RK+hFRIIO+qKujhURCTPo4zhdLfXoRUQCDfq+0o0OxoqIVBX0ZrbczHaY2U4zu2mI6bVmdm82/bdmtiAbv8DMOs1sc/bzzTFu/5DKB2N1YzMRkSoeDm5
mMXAH8F6gFdhkZuvdfVvFbNcBB939XDNbCfw98NFs2vPufvHYNvvEEp11IyLSq5oe/VJgp7vvcvce4B5gxYB5VgB3Za/vB95jZjZ2zRyZWGfdiIj0qibo5wC7K4Zbs3FDzuPuBeAwMDObttDMnjSzX5jZO4f6ADNbbWYtZtbS1tY2ohUYSi47GKsevYjI+B+M3QvMd/dLgBuBu81sysCZ3H2tuze7e3NTU9OoPzTWwVgRkV7VBP0eYF7F8Nxs3JDzmFkCTAXa3b3b3dsB3P1x4HngvNE2+mQSHYwVEelVTdBvAhab2UIzqwFWAusHzLMeWJW9/jDwM3d3M2vKDuZiZouAxcCusWn68FSjFxHpc9Kzbty9YGZrgI1ADKxz961mdivQ4u7rgW8D3zOzncAB0i8DgGXArWaWB0rA9e5+YDxWpFISqUYvIlJ20qAHcPcNwIYB426peN0FfGSI5X4I/HCUbRwx1ehFRPoEeWVsLtZ59CIiZUEGvWr0IiJ9ggz6co1eZ92IiAQa9H0PHlGNXkQkyKDve5SgevQiIkEGvR4lKCLSJ8igz6lGLyLSK8igj2OdRy8iUhZk0Ovh4CIifYIMetXoRUT6BBn0unuliEifIINePXoRkT5BBn35CVOq0YuIBBr05R79Yy+08/Dv9rHvaNcEt0hEZOJUdZviPzRJZCxdOIOHd7Tx8I70GbTvemMTX/4PF/L6qZMmuHUiIq8tcz+9yhvNzc3e0tIyJu91tCvP9r1HeXRXO9/4+fMkkfE3f7aEJWdP4aHt+/j5s/t4Q1MjX7zqfBpqg/zOE5EzhJk97u7NQ04LOegrvdR+nM/ev4XHXkgfcGUG5589hW2vHOG82ZNZ+5fNzJ9ZD0D7sW4AZjbWjnk7RETGg4I+Uyo5P9ryCt2FEn/yxtfRNLmWXz7Xxpq7n8QMVlx0No+9eJDte49Qk0Tc8K5zuf5di6hNYrryRX68ZS9bXznMJfOn8/ZFM2marC8CETk9KOhP4sX9x/lP33ucF/Yfp3nBdN5x7iy27z3Cj7fsZdGsBt7zptfxwBN7aD/eQxJZ79k882fUM7OxhqmTctTXxBzuzHOoI8+x7gL1NQlTJyXMaKjhyvPPYvkFZ1GbxP0+tytfZFfbcV7Yf5w3z53KvBn1r+l6i0g4FPRVcHd6iqV+YfyLZ9v4m39+ht0HO3jPH83m2j9ewGWLZrDtlSP8Zlc7T+85zJHOPIc783T0FJlSlzCtvobG2oSOngJHOgu0HuzglcNdzGio4eqL5wDwYnsa7i+1H6d8BmguNv78snNY8+5zyUURP9ryCv+yeQ/1NQkrL53HFUtm9542WqlQLFEoOXW5eNA0ETlzjDrozWw58HUgBr7l7l8ZML0W+C7wVqAd+Ki7v5hN+xxwHVAEPu3uG0/0WRMV9MPpKZTo6Ckwrb7mlJYvlZz/9/x+/unRl3lw+++piSMWzGpg4ax6zm1qZPHsycyZPokftOzmvpZWapOIQsnpKZQ4b3Yjx7oKvHK4i1mNtbz1nGnki+m0I115Xj3cxf5j3ZQcGmpiZjbWMqMh3cOYVp9jen0N82bUM2/6JM6eNonaJCKJI3oKJZ5qPcTjLx5k+6tHmJSLmV5fw4zGGi6eN40/fsNM5k5P9y6OdRd4qf04L7d38GJ7By8f6KBpci2XLpjOW+ZPH3QQ29050lmg6M70+hxmNuptICInN6qgN7MYeBZ4L9AKbAKucfdtFfP8F+DN7n69ma0EPujuHzWzJcD3gaXA2cC/Aee5e3G4zzvdgn4sdeWL1CbRsOG3c98x7vzF8zTUJnz4rXM5/+wplBx+8ew+vv/Ybl5u76AmicjFRmNdjrOm1DJ7Sh11uZgDx3toP9ZN+/EejnTmOdSZp+1oNx09w/6qmTopxwVzppAvOAc7eth3tJvDnXkA5kybRHehyP5jPf2WmV6f43BnnpKn1yu8bnItDbUJDbUJnT0F9hzs5Hj2mTVxRNPkWmZPqaVpcvo
zuS5HoVgiX3TcnamTckyZlKOxNiFfLNFdSPdQkshIIqPosHPfUba9coRdbcepSSKmTsoxeVKOXGSYgZlRXxMzuS7H5LqEUsk51l2go6dIoeQY6cH3SbmYafU5pk6qyb4Ic0yrryEXG0c6CxzpytPZU6T8PyI2o6E2obEuoaEmpjaJqc1F1CYR9TUxdbmYmjjiYEee9mPdHOrMM7kuYXp9+mXblS9yuDPP0a4CUWTUJumy5es8BorMSGIjiSIiAwfc0y/PQsl7r/Suy0XUJjFJbBRLTqkEnrXaMKIoXdf6moS63PD/3iDtiPQUS/QUSxSKTr5Ywix9HGccGTVxRM2ANpf3frvyJbrzRcyMyXXJCf9tlxWyz6pN4mF/D+XPqLaT4O505Utpe5MgLw2qymiD/u3AF9z9ymz4cwDu/uWKeTZm8/zGzBLgVaAJuKly3sr5hvu8kIP+tebuHOzIs/tAB68e6SJfLKX/kTEumDOFRbMaiQb8B37298f49fP7aXnpIJNrE+bPrOecGQ0smFXPOTMbaKxNONqV54mXD9Hy4gH2Hu7ieHeBY90F6nIxc6ZNYs60SSSx8fsj3ew70sXvj3ax/2gPbce6OdqVJxdH5OIId+dod4GT7VROq8+x5PVTWPy6Rgol51BnniOdeYolxx1K7nTmixzJQjWJjYaahPramDiKwNMY7Owpcqgzz6GOHvJn0H2Q4siILP0SyP5gBvmiV32bkMjSL6Ki+7DbK8m+zErZNjFLr1KvyUqOx3sKdOX7bh1eE0fZlwO9oZ4vlujJvuwjI+vYRL3buVhy4siIs45AT6FER77Y26ZcbEzKxZil0/LFElFkNNYmNNTGxGZ05ot09hTJF9M2Gum69f1u0g5EWXkc0NtJKXpfZ6T8RRibEUVGqeTkK76Yy8sXS+myPcUSZB2lJDLi2NL/E5Fx4dyp3PmXQ2b1SZ0o6Ks5eXwOsLtiuBW4bLh53L1gZoeBmdn4RwcsO2eIBq4GVgPMnz+/iiZJNcyMGQ01zGio4aIq53/jWZN541mT+fg7Fg473+S6HJef18Tl5zWNuo2lUhr2x7sLaShkeyyFklPIev0zGmrGtATk7nT0FDnY0cOhjjyFkjOlLmHKpFwWEmkoFkoljncXOdad51h3kZ5Cie5Cka58ic58ka6eIt3FEtPrc8xsqGXqpBzHewocON7D4Y48dTVxuvdRl+AO3YUi3fkSpSyV3OkXKMVS2nPPF/s/RyEyIxcbcRRR8rR015VP91bK4RIZvWFXLKVffB09RTrzRdw9C8ms55/+IRcbNXG6Z1CbBWoSG+7pe+SzUOsplOgpFnvbElkabrVJRF0uxt05lv2euvNpsFrWnnLnouSkYVuTUJuL6M6X6CoU6cpCutzhzMURtbmIJIoolvc2CuleRmWQlvdwapKIhpqYSTUJJXeOZ3ty7t77JVEsOcd7ChzvLlIsee/eWC5bVyf9Ein//io7v717VdkeUxJFvduiMrjLbSqVnCgqby/DsPS9Sb8Iy79jIw3+8nrks72pudPH54LO0+IqIXdfC6yFtEc/wc2R11AUGVMn5Zg6KfeafaZlJZmG2oS50080Z1oOgrrXqGUi46OagtYeYF7F8Nxs3JDzZKWbqaQHZatZVkRExlE1Qb8JWGxmC82sBlgJrB8wz3pgVfb6w8DPPN3/WQ+sNLNaM1sILAYeG5umi4hINU5auslq7muAjaSnV65z961mdivQ4u7rgW8D3zOzncAB0i8DsvnuA7YBBeCGE51xIyIiY08XTImIBOBEZ92cuSedioicIRT0IiKBU9CLiAROQS8iErjT7mCsmbUBL43iLWYB+8eoOX8ozsR1hjNzvc/EdYYzc71Hus7nuPuQl6ufdkE/WmbWMtyR51CdiesMZ+Z6n4nrDGfmeo/lOqt0IyISOAW9iEjgQgz6tRPdgAlwJq4znJnrfSauM5yZ6z1m6xxcjV5ERPoLsUcvIiIVFPQiIoELJujNbLm
Z7TCznWZ200S3Z7yY2Twze9jMtpnZVjP7TDZ+hpk9aGbPZX+f8JEaf4jMLDazJ83sx9nwQjP7bbbN781uox0UM5tmZveb2e/MbLuZvT30bW1m/y37t/2MmX3fzOpC3NZmts7M9pnZMxXjhty2lvrHbP23mNlbRvJZQQR99gDzO4D3A0uAa7IHk4eoAPx3d18CvA24IVvXm4CH3H0x8FA2HJrPANsrhv8e+Jq7nwscBK6bkFaNr68DP3X3PwIuIl3/YLe1mc0BPg00u/sFpLdGX0mY2/r/AMsHjBtu276f9Hkei0kfu/qNkXxQEEEPLAV2uvsud+8B7gFWTHCbxoW773X3J7LXR0n/488hXd+7stnuAq6ekAaOEzObC3wA+FY2bMC7gfuzWUJc56nAMtLnPeDuPe5+iMC3NelzMiZlT6urB/YS4LZ290dIn99RabhtuwL4rqceBaaZ2eur/axQgn6oB5gPegh5aMxsAXAJ8FtgtrvvzSa9CsyeqHaNk38A/idQfnL2TOCQuxey4RC3+UKgDfhOVrL6lpk1EPC2dvc9wP8CXiYN+MPA44S/rcuG27ajyrhQgv6MY2aNwA+B/+ruRyqnZY9xDOa8WTP7M2Cfuz8+0W15jSXAW4BvuPslwHEGlGkC3NbTSXuvC4GzgQYGlzfOCGO5bUMJ+jPqIeRmliMN+X9y9wey0b8v78plf++bqPaNg3cAV5nZi6RluXeT1q6nZbv3EOY2bwVa3f232fD9pMEf8ra+AnjB3dvcPQ88QLr9Q9/WZcNt21FlXChBX80DzIOQ1aa/DWx3969WTKp8QPsq4F9e67aNF3f/nLvPdfcFpNv2Z+7+58DDpA+jh8DWGcDdXwV2m9kbs1HvIX3+crDbmrRk8zYzq8/+rZfXOehtXWG4bbse+Fh29s3bgMMVJZ6Tc/cgfoA/BZ4Fngdunuj2jON6/jvS3bktwObs509Ja9YPAc8B/wbMmOi2jtP6vwv4cfZ6EfAYsBP4AVA70e0bh/W9GGjJtvc/A9ND39bAF4HfAc8A3wNqQ9zWwPdJj0PkSfferhtu2wJGembh88DTpGclVf1ZugWCiEjgQindiIjIMBT0IiKBU9CLiAROQS8iEjgFvYhI4BT0IiKBU9CLiATu/wNB6HQYWuFoJwAAAABJRU5ErkJggg==\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["pandas.DataFrame({'loss': train_session.train_losses_}).plot();"]}, {"cell_type": "code", "execution_count": 30, "id": "5da877fd", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'init': ,\n", " 'init_1': ,\n", " 'init_b10': }"]}, "execution_count": 31, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.trained_coef_"]}, {"cell_type": "code", "execution_count": 31, "id": "fffb703d", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'init': array([[-0.35357383, 0.6850407 ]], dtype=float32),\n", " 'init_1': array([[-1.916494 , 2.8799832]], dtype=float32),\n", " 'init_b10': array([-1.0036615], dtype=float32)}"]}, "execution_count": 32, "metadata": {}, "output_type": "execute_result"}], "source": ["{k: v.numpy() for k, v in train_session.trained_coef_.items()}"]}, {"cell_type": "markdown", "id": "8caa3421", "metadata": {}, "source": ["Not the same weights? What about the prediction?"]}, {"cell_type": "code", "execution_count": 32, "id": "896d9bc4", "metadata": {}, "outputs": [], "source": ["trained_onx = train_session.get_trained_onnx()"]}, {"cell_type": "code", "execution_count": 33, "id": "360af856", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=14\n", "input: name='x' type=dtype('float32') shape=(0, 2)\n", "init: name='init' type=dtype('float32') shape=(0,) -- array([-0.35357383, 0.6850407 ], dtype=float32)\n", "init: name='init_1' type=dtype('float32') shape=(0,) -- array([-1.916494 , 2.8799832], dtype=float32)\n", "init: name='init_2' type=dtype('int64') shape=(0,) -- array([1], dtype=int64)\n", "init: name='init_3' type=dtype('int64') shape=(0,) -- array([2], dtype=int64)\n", "init: name='init_5' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='init_b10' type=dtype('float32') shape=(0,) -- array([-1.0036615], dtype=float32)\n", "init: 
name='init_b11' type=dtype('int64') shape=(0,) -- array([-1, 1], dtype=int64)\n", "Mul(x, x) -> out_mul_0\n", " Mul(out_mul_0, init) -> out_mul_0_1\n", "Mul(x, init_1) -> out_mul_0_2\n", " Add(out_mul_0_2, out_mul_0_1) -> out_add_0\n", " Slice(out_add_0, init_2, init_3, init_2) -> out_sli_0\n", " Squeeze(out_sli_0, init_2) -> out_squ_0\n", " Slice(out_add_0, init_5, init_2, init_2) -> out_sli_0_1\n", " Squeeze(out_sli_0_1, init_2) -> out_squ_0_1\n", " Add(out_squ_0_1, out_squ_0) -> out_add_0_1\n", " Add(out_add_0_1, init_b10) -> out_add_0_2\n", " Reshape(out_add_0_2, init_b11) -> y\n", "output: name='y' type=dtype('float32') shape=(0, 1)\n"]}], "source": ["print(onnx_simple_text_plot(trained_onx))"]}, {"cell_type": "code", "execution_count": 34, "id": "0aee38d3", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[-0.6123954],\n", " [-1.303561 ],\n", " [-2.0257921],\n", " [-1.2778704],\n", " [-0.9708453]], dtype=float32)"]}, "execution_count": 35, "metadata": {}, "output_type": "execute_result"}], "source": ["oinf = OnnxInference(trained_onx)\n", "oinf.run({'x': X})['y'][:5]"]}, {"cell_type": "code", "execution_count": 35, "id": "dc82cbb8", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[-0.58675164],\n", " [-1.3148587 ],\n", " [-2.0666485 ],\n", " [-1.272753 ],\n", " [-0.95404863]], dtype=float32)"]}, "execution_count": 36, "metadata": {}, "output_type": "execute_result"}], "source": ["y[:5]"]}, {"cell_type": "markdown", "id": "4a2ac16b", "metadata": {}, "source": ["The predictions are close to the expected values: it works."]}, {"cell_type": "markdown", "id": "854502c0", "metadata": {}, "source": ["## MLPRegressor"]}, {"cell_type": "code", "execution_count": 36, "id": "c9c14cfc", "metadata": {}, "outputs": [], "source": ["import warnings\n", "import time\n", "import numpy\n", "import matplotlib.pyplot as plt\n", "from pandas import DataFrame\n", "from onnxruntime import get_device\n", "from sklearn.datasets import make_regression\n", "from sklearn.model_selection import 
train_test_split\n", "from sklearn.neural_network import MLPRegressor\n", "from skl2onnx import to_onnx\n", "\n", "\n", "X, y = make_regression(1000, n_features=100, bias=2)\n", "X = X.astype(numpy.float32)\n", "y = y.astype(numpy.float32)\n", "X_train, X_test, y_train, y_test = train_test_split(X, y)"]}, {"cell_type": "code", "execution_count": 37, "id": "c36d7f35", "metadata": {}, "outputs": [], "source": ["batch_size = 15\n", "max_iter = 100\n", "\n", "nn = MLPRegressor(hidden_layer_sizes=(50, 10), max_iter=max_iter,\n", " solver='sgd', learning_rate_init=5e-5,\n", " n_iter_no_change=max_iter * 3, batch_size=batch_size,\n", " learning_rate=\"invscaling\",\n", " # default values\n", " momentum=0.9, nesterovs_momentum=True, power_t=0.5)\n", "\n", "with warnings.catch_warnings():\n", " warnings.simplefilter('ignore')\n", " nn.fit(X_train, y_train)"]}, {"cell_type": "markdown", "id": "4d5a10f9", "metadata": {}, "source": ["Conversion to ONNX"]}, {"cell_type": "code", "execution_count": 38, "id": "2e07fa9e", "metadata": {}, "outputs": [], "source": ["from onnxcustom.utils.onnx_helper import onnx_rename_weights\n", "onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)\n", "onx = onnx_rename_weights(onx)"]}, {"cell_type": "code", "execution_count": 39, "id": "49050eb7", "metadata": {}, "outputs": [], "source": ["train_session = OrtGradientForwardBackwardOptimizer(\n", " onx, device='cpu', learning_rate=5e-5,\n", " warm_start=False, max_iter=max_iter, batch_size=batch_size)"]}, {"cell_type": "code", "execution_count": 40, "id": "d658b104", "metadata": {}, "outputs": [{"data": {"text/plain": ["OrtGradientForwardBackwardOptimizer(model_onnx='ir_version...', weights_to_train=\"['I0_coeff...\", loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=15, learning_rate=LearningRateSGD(eta0=5e-05, alpha=0.0001, power_t=0.25, learning_rate='invscaling'), value=1.5811388300841898e-05, device='cpu', warm_start=False, verbose=0, 
validation_every=10, learning_loss=SquareLearningLoss(), enable_logging=False, weight_name=None, learning_penalty=NoLearningPenalty(), exc=True)"]}, "execution_count": 41, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 41, "id": "2c2e7ccc", "metadata": {}, "outputs": [{"data": {"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAArzElEQVR4nO3deZRV5Z3u8e/vzDVQUECJCCiY4AgJKqKJ7RQ7iiY3qOnrkEFjVG4SM3Q6N8as2CuJHTu56W67k+60XtMhwazEIXEIN0HTxgxqEg2FouAEiKCFBRQFRRU1nel3/9i74AhVUNR0oPbzWeusc867p3dzlId32HubuyMiItEWK3cFRESk/BQGIiKiMBAREYWBiIigMBARESBR7goM1MSJE3369OnlroaIyCFl+fLlW929bs/yQzYMpk+fTn19fbmrISJySDGzDb2Vq5tIREQUBiIiojAQEREO4TEDEYmmXC5HQ0MDXV1d5a7KQS2TyTB16lSSyWS/1lcYiMghpaGhgTFjxjB9+nTMrNzVOSi5O83NzTQ0NDBjxox+baNuIhE5pHR1dTFhwgQFwT6YGRMmTDig1pPCQEQOOQqC/TvQP6PohcFz98KyH5S7FiIiB5XohcELD8Azi8tdCxGRg8p+w8DMFpnZFjNbVVJ2r5mtCF/rzWxFWD7dzDpLlt1Rss0pZrbSzNaa2XctbMOY2Xgze9TM1oTvtcNwnrsl0pDTLAQRGVrTp09n69ate5VXV1f3uc369euZNWvWcFar3/rTMvgRML+0wN0vd/c57j4HuB94oGTxqz3L3P0TJeW3A9cDM8NXzz5vAh5z95nAY+H34ZPIQF5hICJSar9TS939cTOb3tuy8F/3lwHv2dc+zGwyUOPuT4Xf7wIuBh4GFgDnhKsuBn4PfKk/lR8QhYHIqPH1//cCL77ZOqT7POGIGr76P07c5zrt7e1cdtllNDQ0UCgU+Pu///tdyzo7O7n00ku59NJLuf766/t93K6uLj75yU9SX19PIpHgtttu49xzz+WFF17gmmuuIZvNUiwWuf/++zniiCP2Ov7ll18+4HOGwV9ncCaw2d3XlJTNMLNngVbgZnd/ApgCNJSs0xCWAUxy98bw8yZgUl8HM7OFwEKAI488cmA1VhiIyCA98sgjHHHEEfzqV78CYMeOHXzpS19i586dXHHFFVx11VVcddVVB7TP733ve5gZK1eu5OWXX+b8889n9erV3HHHHXzuc5/jwx/+MNlslkKhwNKlS/c6/mANNgyuBO4u+d4IHOnuzWZ2CvCQme07Yku4u5uZ72P5ncCdAHPnzu1zvX1KZiDfPaBNReTgsr9/wQ+X2bNn84UvfIEvfelLvP/97+fMM88EYMGCBdx44418+MMfPuB9Pvnkk3zmM58B4LjjjuOoo45i9erVvOtd7+LWW2+loaGBSy+9lJkzZ/Z5/MEY8GwiM0sAlwL39pS5e7e7N4eflwOvAscAG4GpJZtPDcsANofdSD3dSVsGWqd+6WkZ+MCyRETkmGOO4ZlnnmH27NncfPPN3HLLLQCcccYZPPLII/gQ/v3yoQ99iCVLllBRUcFFF13Eb3/72z6PPxiDmVr618DL7r
6r+8fM6swsHn4+mmCgeF3YDdRqZqeH4wxXAb8IN1sCXB1+vrqkfHgk0sG7WgciMkBvvvkmlZWVfOQjH+GLX/wizzzzDAC33HILtbW13HDDDQe8zzPPPJOf/OQnAKxevZrXX3+dY489lnXr1nH00Ufz2c9+lgULFvD888/3efzB6M/U0ruBPwPHmlmDmV0bLrqCt3YRAZwFPB9ONf058Al33xYu+xTwX8BaghbDw2H5t4D3mtkagoD51sBPpx8SFcF7vnNYDyMio9fKlSuZN28ec+bM4etf/zo333zzrmXf+c536Ozs5MYbbzygfX7qU5+iWCwye/ZsLr/8cn70ox+RTqe57777mDVrFnPmzGHVqlVcddVV+zz+QNlQNmdG0ty5c31ATzpb9gP41d/BF16BMYcPfcVEZFi99NJLHH/88eWuxiGhtz8rM1vu7nP3XDd6VyAnMsG7ZhSJiOwSvVtYJ8Mw0FXIIjJCVq5cyUc/+tG3lKXTaZ5++uky1Whv0QsDtQxEZITNnj2bFStWlLsa+xS5bqJ/+8MbwQeFgYjILpELg3iqZzaRwkBEpEfkwmDsmOAOgoWswkBEpEfkwqC2pgaAtp1tZa6JiMjBI3JhMH5sEAY72hQGIjK8/vEf/3G/6+zreQcjKXJhMLF2LACtbTvLXBMRGa3cnWKx2K8wOFhEbmrpYWEY7NypMBA55D18E2xaObT7PHw2XLj/u+LcdtttLFq0CIDrrruOiy++mAsuuIDTTjuN5cuXM2/ePDo7O5kzZw4nnnjirvsO9cXdufHGG3n44YcxM26++WYuv/xyGhsbufzyy2ltbSWfz3P77bfz7ne/m2uvvZb6+nrMjI9//ON8/vOfH9RpRy4MxtUETbKODoWBiAzM8uXL+eEPf8jTTz+Nu3Paaadx9tlns2bNGhYvXszpp58OwM9+9rN+X1/wwAMPsGLFCp577jm2bt3KqaeeyllnncVPf/pTLrjgAr7yla9QKBTo6OhgxYoVbNy4kVWrgqcRt7S0DPqcIhcGFt6orrOzvcw1EZFB68e/4IfDk08+ySWXXEJVVRUAl156KU888QRHHXXUriAYyD6vvPJK4vE4kyZN4uyzz2bZsmWceuqpfPzjHyeXy3HxxRczZ84cjj76aNatW8dnPvMZ3ve+93H++ecP+pwiN2ZAPEGeON2dumupiAytnnAYSmeddRaPP/44U6ZM4WMf+xh33XUXtbW1PPfcc5xzzjnccccdXHfddYM+TvTCAMjH0uS6O8pdDRE5RJ155pk89NBDdHR00N7ezoMPPtjr08aSySS5XK7f+7z33nspFAo0NTXx+OOPM2/ePDZs2MCkSZO4/vrrue6663jmmWfYunUrxWKRD37wg3zjG98YkucZRK6bCKAQS1Po7qRQdOIxK3d1ROQQc/LJJ/Oxj32MefPmAcEAcm1t7V7rLVy4kHe84x2cfPLJ+x1AvuSSS/jzn//MO9/5TsyMb3/72xx++OEsXryYf/qnfyKZTFJdXc1dd93Fxo0bueaaaygWiwB885vfHPQ5Re95BkD7t45l6c5jOPOLP+PwsZkhrpmIDCc9z6D/9DyD/UlWkLEsjTs0biAiAhHtJoonM6TJ0biji5PKXRkRGfWam5s577zz9ip/7LHHmDBhQhlqtLdIhkEiXUmaHG/s0M3qRA5F7o7ZoTPeN2HChBF/nsGBDgHst5vIzBaZ2RYzW1VS9jUz22hmK8LXRSXLvmxma83sFTO7oKR8fli21sxuKimfYWZPh+X3mlnqgM5gAOKpDBWxHI0t6iYSOdRkMhmam5sP+C+7KHF3mpubyWT6Pyban5bBj4D/AO7ao/xf3f2fSwvM7ATgCuBE4AjgN2Z2TLj4e8B7gQZgmZktcfcXgf8T7useM7sDuBa4vd9nMACWyDAmnqexVS0DkUPN1KlTaWhooKmpqdxVOahlMhmmTp3a7/X3Gwbu/riZTe/n/hYA97h7N/Cama0F5oXL1rr7Og
AzuwdYYGYvAe8BPhSusxj4GsMcBiQrqIrl1TIQOQQlk0lmzJhR7mqMOoOZTfRpM3s+7EbqmWA7BXijZJ2GsKyv8glAi7vn9yjvlZktNLN6M6sf1L8KEmkysRybNGYgIgIMPAxuB94GzAEagX8Zqgrti7vf6e5z3X1uXV3dwHeUqCBDjs1t3RSK6ncUERlQGLj7ZncvuHsR+D67u4I2AtNKVp0alvVV3gyMM7PEHuXDK5Em5VkKRWdLm1oHIiIDCgMzm1zy9RKgZ6bREuAKM0ub2QxgJvAXYBkwM5w5lCIYZF7iwXSA3wF/E25/NfCLgdTpgCQyJDwLQKO6ikRE9j+AbGZ3A+cAE82sAfgqcI6ZzQEcWA/8LwB3f8HM7gNeBPLADe5eCPfzaeDXQBxY5O4vhIf4EnCPmX0DeBb4wVCdXJ+SGeKFIAQaW7rgyGE/oojIQa0/s4mu7KW4z7+w3f1W4NZeypcCS3spX8fubqaRkchgXiBOQbekEBEhqvcmSgQXYoxNFtRNJCJCxMPgqJqYppeKiBDZMEgDMKXaeFPdRCIiEQ2DZPAc5MlV0NTWXebKiIiUXzTDIGwZVMby5ArFMldGRKT8IhoGwZhBxnLkCroCWUQk0mFQYTlyebUMREQiHQYZcmTVTSQiEtEwSPZ0E2U1ZiAiQlTDIGwZpMlRdHTnUhGJvIiGQTCbKEVwszq1DkQk6iIaBsF1BpkwDDRuICJRF9Ew6GkZ5ADIa3qpiERcNMMgvAI56eomEhGBqIZBPAVAKgyDrK41EJGIi2YYmEEio5aBiEgommEAYRgEN6nTLSlEJOoiHQaJoloGIiIQ5TBIZkgWg5aBppaKSNTtNwzMbJGZbTGzVSVl/2RmL5vZ82b2oJmNC8unm1mnma0IX3eUbHOKma00s7Vm9l0zs7B8vJk9amZrwvfaYTjPvSUyJHrGDDSALCIR15+WwY+A+XuUPQrMcvd3AKuBL5cse9Xd54SvT5SU3w5cD8wMXz37vAl4zN1nAo+F34dfIk18VzeRxgxEJNr2Gwbu/jiwbY+y/3b3fPj1KWDqvvZhZpOBGnd/yt0duAu4OFy8AFgcfl5cUj68EhXEC8HzjzVmICJRNxRjBh8HHi75PsPMnjWzP5jZmWHZFKChZJ2GsAxgkrs3hp83AZP6OpCZLTSzejOrb2pqGlytE2niGjMQEQEGGQZm9hUgD/wkLGoEjnT3k4C/A35qZjX93V/Yauizz8bd73T3ue4+t66ubhA1B5IVxAo9U0sVBiISbYmBbmhmHwPeD5wX/iWOu3cD3eHn5Wb2KnAMsJG3diVNDcsANpvZZHdvDLuTtgy0TgckkVYYiIiEBtQyMLP5wI3AB9y9o6S8zszi4eejCQaK14XdQK1mdno4i+gq4BfhZkuAq8PPV5eUD69EZncY5DWALCLRtt+WgZndDZwDTDSzBuCrBLOH0sCj4QzRp8KZQ2cBt5hZDigCn3D3nsHnTxHMTKogGGPoGWf4FnCfmV0LbAAuG5Iz259EBgsHkDVmICJRt98wcPcreyn+QR/r3g/c38eyemBWL+XNwHn7q8eQS2SwvLqJREQg4lcgW15TS0VEIMphkMhghW7AddGZiERehMMgeNpZmpyeZyAikRfhMAiedlYdz6mbSEQiL8JhELQMquIFhYGIRF50wyB8DnJ1LK8xAxGJvOiGwa6WQV7XGYhI5EU4DDIAVMVyep6BiESewiCe15iBiERe5MOgUmMGIiIRDoNkEAYVMY0ZiIhENwx6WgaWVTeRiERe5MOgwjRmICKiMLCsnmcgIpEX+TDIWE5jBiISeREOg+Cis4zp3kQiIhEOg7BlgMJARCS6YRBPQCwRtgw0ZiAi0RbdMABIVJAmq+cZiEjk9SsMzGyRmW0xs1UlZePN7FEzWxO+14blZmbfNbO1Zva8mZ1css3V4fprzOzqkvJTzGxluM13zcyG8iT7lEiTRtcZiIj0t2XwI2D+HmU3AY+5+0zgsf
A7wIXAzPC1ELgdgvAAvgqcBswDvtoTIOE615dst+exhkciQ0pjBiIi/QsDd38c2LZH8QJgcfh5MXBxSfldHngKGGdmk4ELgEfdfZu7bwceBeaHy2rc/Sl3d+Cukn0Nr2SGlGc1ZiAikTeYMYNJ7t4Yft4ETAo/TwHeKFmvISzbV3lDL+V7MbOFZlZvZvVNTU2DqHookSFFVtcZiEjkDckAcvgv+mH/57W73+nuc919bl1d3eB3mMiQ8m5yhSLBKYiIRNNgwmBz2MVD+L4lLN8ITCtZb2pYtq/yqb2UD79EhmQxhzsUigoDEYmuwYTBEqBnRtDVwC9Kyq8KZxWdDuwIu5N+DZxvZrXhwPH5wK/DZa1mdno4i+iqkn0Nr0SahGcBNG4gIpGW6M9KZnY3cA4w0cwaCGYFfQu4z8yuBTYAl4WrLwUuAtYCHcA1AO6+zcz+AVgWrneLu/cMSn+KYMZSBfBw+Bp+yQqS3g1AtlCkgviIHFZE5GDTrzBw9yv7WHReL+s6cEMf+1kELOqlvB6Y1Z+6DKlEmkQxCANNLxWRKIv2FcipapL5dkBhICLRFu0wqBhHKt8KuJ5pICKRFu0wyIwlXsyRRs80EJFoi3gYjANgLO3qJhKRSIt2GFSMA2CstZPX1FIRibBoh8GulsFOdROJSKRFOwxKWgbqJhKRKIt2GGjMQEQEUBgAUGMdCgMRibSIh8FYIGgZZHWdgYhEWLTDIJ6gmKzWmIGIRF60wwAoZsYqDEQk8iIfBp4eR40GkEUk4hQGFeOosQ6yuuhMRCIs8mFgmbHB1NK8WgYiEl0Kg4pxGjMQkchTGFTW6qIzEYm8yIdBrGIcldZNPpctd1VERMpmwGFgZsea2YqSV6uZ/a2Zfc3MNpaUX1SyzZfNbK2ZvWJmF5SUzw/L1prZTYM9qQM6j4paAOLdLSN5WBGRg0q/noHcG3d/BZgDYGZxYCPwIHAN8K/u/s+l65vZCcAVwInAEcBvzOyYcPH3gPcCDcAyM1vi7i8OtG4HJLwKOZ5tHZHDiYgcjAYcBns4D3jV3TeYWV/rLADucfdu4DUzWwvMC5etdfd1AGZ2T7juCIXBOACS2R0jcjgRkYPRUI0ZXAHcXfL902b2vJktMrPasGwK8EbJOg1hWV/lIyO8jXVCLQMRibBBh4GZpYAPAD8Li24H3kbQhdQI/Mtgj1FyrIVmVm9m9U1NTUOz07BlkMopDEQkuoaiZXAh8Iy7bwZw983uXnD3IvB9dncFbQSmlWw3NSzrq3wv7n6nu89197l1dXVDUHV2tQxSeYWBiETXUITBlZR0EZnZ5JJllwCrws9LgCvMLG1mM4CZwF+AZcBMM5sRtjKuCNcdGeEAclphICIRNqgBZDOrIpgF9L9Kir9tZnMAB9b3LHP3F8zsPoKB4Txwg7sXwv18Gvg1EAcWufsLg6nXAUmk6SJNOrdzxA4pInKwGVQYuHs7MGGPso/uY/1bgVt7KV8KLB1MXQajPVZNptBWrsOLiJRd5K9AhiAMKhQGIhJhCgOgI1ZNZVFhICLRpTAAOuNjqCxozEBEokthQBAGVa4wEJHoUhgA3YkxVBcVBiISXQoDoCs+hio6oFgod1VERMpCYQB0J2uCD126WZ2IRJPCAMj2hEHn9vJWRESkTBQGQG5Xy6ClrPUQESkXhQGQTwb3J1I3kYhElcIAyKd7uolayloPEZFyURgAhVTQMnCFgYhElMIAKKaDMChqAFlEIkphAFiygm5PUOxQGIhINCkMgGQiTitVeEdLuasiIlIWCgMgmYixw6soajaRiESUwgBIxY1WKjFdZyAiEaUwAJLxoGWgi85EJKoUBoRhQBUxdROJSEQNOgzMbL2ZrTSzFWZWH5aNN7NHzWxN+F4blpuZfdfM1prZ82Z2csl+rg7XX2NmVw+2XgciGTdavJpY5zZwH8lDi4gcFIaqZXCuu89x97nh95uAx9x9JvBY+B3gQmBm+FoI3A5BeABfBU4D5g
Ff7QmQkZCMx3jZjySea4OmV0bqsCIiB43h6iZaACwOPy8GLi4pv8sDTwHjzGwycAHwqLtvc/ftwKPA/GGq216S8Rh/LJ4YfHntDyN1WBGRg8ZQhIED/21my81sYVg2yd0bw8+bgEnh5ynAGyXbNoRlfZW/hZktNLN6M6tvamoagqoHkvEYb/gkuqqmwjqFgYhEz1CEwV+5+8kEXUA3mNlZpQvd3QkCY9Dc/U53n+vuc+vq6oZilwCkEgbA9sPfBeufhEJ+yPYtInIoGHQYuPvG8H0L8CBBn//msPuH8H1LuPpGYFrJ5lPDsr7KR0QyHvwxbK17F3TvgMbnRurQIiIHhUGFgZlVmdmYns/A+cAqYAnQMyPoauAX4eclwFXhrKLTgR1hd9KvgfPNrDYcOD4/LBsRPWHQNHFeUPDa70fq0CIiB4XEILefBDxoZj37+qm7P2Jmy4D7zOxaYANwWbj+UuAiYC3QAVwD4O7bzOwfgGXhere4+7ZB1q3fesJgZ3ICHHZCMG5w5hdG6vAiImU3qDBw93XAO3spbwbO66XcgRv62NciYNFg6jNQqTAMcvkizDgblv8Qcl2QzJSjOiIiI05XIAPJcAA5VyjC0WdDvgveeLrMtRIRGTkKA3Z3E+UKRTjqDLC4rjcQkUhRGLA7DLIFh0wNTDlZ1xuISKQoDCgZMygUg4Kjz4E3n4H2reWrlIjICFIYENyoDsIBZIATFoAX4cWHylcpEZERpDAA4jHDrKRlMGkW1B0HK39e3oqJiIwQhQFgZiTjsWDMICiA2X8Dr/8ZWt7Y98YiIqOAwiCUisd2twwAZn0weF91f3kqJCIyghQGoWTc3hoG44+GKXPVVSQikaAwCCX3bBkAzP6fsHklbHm5PJUSERkhCoNQMh4jm9/jTtsnXgIWg1VqHYjI6KYwCKUSvbQMxkyCGWcFXUV6NrKIjGIKg9BeYwY9Zv0NbH8NGleMeJ1EREaKwiDU65gBwLEXBfcqeun/jXylRERGiMIg9JbrDEpVTYDpZygMRGRUUxiEUvHY7ttR7On4D8DW1dD0yshWSkRkhCgMQslEH2MGAMe9P3h/acnIVUhEZAQpDEJ9jhkA1EyGqfPgRYWBiIxOCoNQn2MGPY7/H7Dpedi+fsTqJCIyUgYcBmY2zcx+Z2YvmtkLZva5sPxrZrbRzFaEr4tKtvmyma01s1fM7IKS8vlh2Vozu2lwpzQwe92baE/H93QV/XJkKiQiMoIG0zLIA19w9xOA04EbzOyEcNm/uvuc8LUUIFx2BXAiMB/4TzOLm1kc+B5wIXACcGXJfkZMn9cZ9Bh/NEyarVlFIjIqDTgM3L3R3Z8JP7cBLwFT9rHJAuAed+9299eAtcC88LXW3de5exa4J1x3RCX3NZuoxwkfgDeehnW/H5E6iYiMlCEZMzCz6cBJwNNh0afN7HkzW2RmtWHZFKD04QANYVlf5b0dZ6GZ1ZtZfVNT01BUfZdkYj9jBgAnXw11x8KPL4E//YduUSEio8agw8DMqoH7gb9191bgduBtwBygEfiXwR6jh7vf6e5z3X1uXV3dUO0WCMYMunMFfF9/wY+ZBNf9Bo57H/z3V+CB6yGfHdJ6iIiUw6DCwMySBEHwE3d/AMDdN7t7wd2LwPcJuoEANgLTSjafGpb1VT6ijjt8DG3deVZtbN33iukxcNmP4T03w8qfwV/uHJkKiogMo8HMJjLgB8BL7n5bSfnkktUuAVaFn5cAV5hZ2sxmADOBvwDLgJlmNsPMUgSDzCM+of/CWZNJxWM8+Gw/csgMzvoivO098MQ/Q9eO4a+giMgwGkzL4Azgo8B79phG+m0zW2lmzwPnAp8HcPcXgPuAF4FHgBvCFkQe+DTwa4JB6PvCdUfU2Mok5x5Xx5Ln3iS/r1lFpf76a9C5Hf74nWGtm4jIcEsMdEN3fxKwXhYt3cc2twK39lK+dF/bjZRLTprCr1/YzJ9ebeasY/oxJjH5ncHT0P78n3Dq9c
GVyiIihyBdgVzinGMPoyaT4KH+dBX1OPcrUMzDH741fBUTERlmCoMSmWSci2ZP5pEXNtGRzfdvo/EzYO7H4Zkfw4Y/DW8FRUSGicJgDxefNIWObIFHX9zc/43OvhFqj4LFH4C/fF/XH4jIIUdhsId508dzxNjMgXUVVU2E638XzC5a+r/hoU/BppWQ6xq+ioqIDKEBDyCPVrGYcfFJU7jjD6/yl9e2MW/G+P5tWDEOrrwHHv82/P6b8NxPwWJQOx2OfDfMfC+87VzIjB3O6ouIDIjt84rbg9jcuXO9vr5+WPbd1pXjA//xR9q78/zqs2dSNyZ9YDtofhUaV0DTati8CtY/EVyLYHE49kI4429h2qnDUXURkX0ys+XuPnevcoVB715qbOXi7/2RU46q5cfXnkY81tss2n4q5GFjPbz8S3jmriAYjnw3HD4bOrZCRzMccXJwIVuqcuhOQkRkD32FgcYM+nD85Br+4eJZ/OnVZv7tN6sHt7N4Ao48Hc7/Bnz+RZj/LWjdCM/dA28+C50t8ORtcPu7dEdUESkLjRnsw2Vzp1G/fhv//tu1dOUK3Dj/OJLxQeZnuhpO/2TwKrX+SVjyWbhrAcxbGARGLD64Y4mI9JPCYD++cfFsMsk433/iNZ59vYV//9BJTB5bMfQHmv5X8Mk/wmO3wFP/Ce1b4ZL/C4nU0B9LRGQP6ibaj1Qixi0LZvHvV57ES42tXPSdJ1j05Gt05QpDf7BkBcz/Jrz3FnjhAbj3I5DrHPrjiIjsQQPIB2Dtlp38/UOr+PO6Zg6vyXDDuW/j0pOnUpUehgZW/SL45d8FU1On/xVMnQtjp0LHNti5JQiOd1wedDv12NkEr/0hmL5aNREqxu/uakpkgjIRiTTNJhpCf3p1K7f992rqN2ynMhXnwlmT+eApU5g3fTyJwY4plHrpl7D8h7BxeXB31D1VToB3fxbefh4s+wE8dzfk93Gh28lXwQXffGuAiEikKAyGmLtTv2E79y9v4JfPN7KzO09NJsEZb5/ImTPrOHV6LW+rqyY2mCmpuw8G21+Dts1QVRf8C3/ravj9t+DVx4J14ml45xVwytVQLEJ7E3RuAw9vx735RXj6jqClcemdMG1en4cTkdFLYTCMOrMFfvvyFh5f3cTja5po3BH867w6neAdU8cya8pYjp88huMOr2HGxCoyySGcJfTGsqDlcOIlwWM592X9H+HBT0BrAxx2Ihw+CybNgrFTglZGZmwQGuufgNf/DGMmw3Hvh+PfD+OOHLo6i0jZKAxGiLuzbms7z77ewoo3trPijRZWb9pJNnxgjhlMrslw1IQqpo2vYMq4SqbUVnDE2AyTxmY4vCYzPGMQPbp2BM9f2FgPm1bBzk17r1NRG1wU17IhuIIaIDMOUlWQrIR4cvfN+ConwJSTYMopMPGYcJ2qoCsqucesK/egpaIpsyJlozAoo1yhyPqt7bzY2Mr6rR1saG5nfXM7Dds72dLWvdf61ekEE6tTTKxOM6E6xfiqFOMqU9RWJsP34HNNRZKaTJIxmQSVqTjBk0gPUHsz7NwcXAXduQ3GHx20GmLh2Efzq/DKUmh5HbIdkGsPnt/Qo7URNj0Pheze+05UBMGSzAQh1LUjuF/TjLPhuPfBMfN7fyBQrivYplRPV5nFgoHx9JggWUXkgCgMDlLd+QKNLV28uaOTza1dNO7oYktrN1t3Bq/mnVm2d+Ro6ciSL/b9W8VjRnU6QXU6QVU6TlU6QVUqCImqdPBemYpTEZZVJONUhO+ZZJxMMha8J3Z/TidipBNxUokYqUSs71ty5LNBC6JlQxgYHdDdGgx6d24Ppsdmxgati1wnrH4Ytq8Ptq2cABNmBjOlWjfC1jXBLTqqDgtu11F3bBBIDX956yB6LAkTZwYtkimnBPvvbguOu6MhGFPZuiaoS+WEIEBSlUGQFQuQroGj3hXM1Dr8ncFV4hCETteOYMylozkYkC/kghZN9WFQMzUYs1EQySFKYX
CIc3fauvPs6MixvSNLS0eOtq48rV05dnTm2NmVZ2d38L0zW6A9W6C9Ox+8snnauwt0ZPN05fr5fOdexGNGOgyGZDxGKh576/dEjGTcdi1L9PI5+G4c3r2Oo1uXUde1gdrO9Yzp3kx7xWTaqo6iq3IyYzo3Mq71Zca0raOzaho7Jp5E28Q5xOIJUtkWUtntVG1/maqtz5HobnlLPYvJSvK1b6cwfiaeribetZ1453Ys3xGESDxBbOdmrHnN7o1iieDlxd5bOW/5g0hDzRFBgNVMCbrD4slg+3x3ECC5zqAVk0gHy7p3BiHXvjUIp2nzYNppMHYaEHaf9fy/aBbsp3MbdGyHQncQZpUTgv21bw1ac7mOIBAPOxGq64KJA10tu0MzFod4Cion6uJF2eWgDwMzmw98B4gD/+Xu+3yOZNTCYKgUi05XvkBntkBHtkBnrkB3rkhXvkBXrkBXrkh3fvd7Nl+kO18M34N1s4UiucLu8lwheA/Kfdf3fM/nwt6f88Vg3cFzptkW0uTY6RW0U0EbFfT+eO63qqOF0+MvcWxsIwkrkDDHMLbZWLbbOFqthpylyMdSmBmHsZ1JNDPJtzLJt3KYb6XOt5L0HHEKJMiTI0nW0uQsheEkyZPwHN2xCtri49gZG0ttoYlp2XXEGHgw76k7Xkmy0NXrPh2jIzWBjtREEsVuMrkW0vlWnBj5eAX5RAW4E/ccsWKOYixFLlFNLjmGXHIM+UQ1uUQVMZxUroV0djvxwu4pzMVEJR3VR9JRdRS59FgyXVvIdGwikd9JtmISXVVTyGUmkCh0kcjvJJ7vpJCqIZ8ZTyFVQ6zQSTK3k0SulUT3DuK5NhK5NnKVk+iqPYbucTPxRIZ4IUus2A2xOJ6sxpOVxLM7SLW+TqrtdWLFHIWK8RQyE/DMOIqZsXhqDCQzmIMRBK4ZWCxGvHsHyc3Pkdr8LImW9Xi6hmJFLV55GIWJx1CYeDzFCcdgFsOK3cTyWci1EetuJZbdCRXjKFZPxipqsZhh2K4GoxnBd4pQyGGFLBZPYIkMWCxYXqbW5UEdBmYWB1YD7wUagGXAle7+Yl/bKAwOfe5OvuhBUBTDkCgUyRWdQklZoRi8csUixeLubQruFMJQ6SnvWTdYFpQVS957ynteRS9ZFi53Z9fn0vJi0Sk6bykvelBWDPfrvX0O13Hv2QbSxQ7enlvN2GILRWIUHHZloztZj7Pdx7Cdaro9QY23MdbbSHk3zYxli4+ls5jgbbaRmbzONDazwytp9hp2eBWOkbACKfLU0cJka2aSbaeDNC0+hhaCv9wr6aKSborEyJIgR4IUOcZYJ2Npp9o6qaaTMdaBY2zzMbR4NR1k6KnuGDo4yrYwxbYSs6Dum3w8O6lkkm1jgrW95XcvuBG33v/e6fQUrVTS7hkm2zYqbD+ttCGw3at51Y+gik5qbScTaCVp/b/DQLcn6SRFnjhFYsQpkCZHhiwJ2zucuz34cy5igJEnRjcpukniGBmyZMiSpECOBNlw3QQFkuRJUqD5I49y5Mx3DOh8+wqDg+XeRPOAte6+DsDM7gEWAH2GgRz6zCzsVoIKojjD6IJh2av77iAqOji7A84Jyr0YlPcEmYfrFYu71+8pB8iH+6p2pzIMNodd67U5vFzowrp3UsjUUsRwhzeAN7LtxDubKSQrySeqKVoSy7WT6NpGrLuFQqKKfGoM2UQ1xVh6Vx03FAqkOxqo3PEqVixQiKfJWxIrFogXOojnOsglqmivmkZ7xREUYykSuVaS3dtIdreQCFsY8UIQdm6GY7u64/KxNNtqTqCtYioeFjtghSxjOl5nXOtqato34BYjb0kKsRTZeBXZ5BiysQrS+TYquzdT2dVEvJjFvIB5nqIlKMRS5C1FIZakSIJC2AWZKHQT925ixXzQSqFIrJgn7tld+2iLZcjHMhQsQczzxIo5Yl6gYEmKFqcQS/L2mtoh/+/mYAmDKQT/3fRoAE7bcyUzWwgsBDjySM17F+
mNWdBdEetHV9nQqgEO66V8HMH/4qUmAP35f3gyUI4HQc0EzivDccvnkLpRnbvf6e5z3X1uXV1duasjIjJqHCxhsBGYVvJ9algmIiIj4GAJg2XATDObYWYp4ApgSZnrJCISGQfFmIG7583s08CvCaaWLnL3F8pcLRGRyDgowgDA3ZcCS8tdDxGRKDpYuolERKSMFAYiIqIwEBGRg+R2FANhZk3AhgFuPhHYOoTVOVRE8byjeM4QzfPWOffPUe6+14Vah2wYDIaZ1fd2b47RLornHcVzhmiet855cNRNJCIiCgMREYluGNxZ7gqUSRTPO4rnDNE8b53zIERyzEBERN4qqi0DEREpoTAQEZHohYGZzTezV8xsrZndVO76DAczm2ZmvzOzF83sBTP7XFg+3sweNbM14fvQPy6pzMwsbmbPmtkvw+8zzOzp8Pe+N7wr7qhiZuPM7Odm9rKZvWRm7xrtv7WZfT78b3uVmd1tZpnR+Fub2SIz22Jmq0rKev1tLfDd8PyfN7OTD+RYkQqD8FnL3wMuBE4ArjSzE8pbq2GRB77g7icApwM3hOd5E/CYu88EHgu/jzafA14q+f5/gH9197cD24Fry1Kr4fUd4BF3Pw54J8H5j9rf2symAJ8F5rr7LII7HV/B6PytfwTM36Osr9/2QoJHtM0keCLk7QdyoEiFASXPWnb3LNDzrOVRxd0b3f2Z8HMbwV8OUwjOdXG42mLg4rJUcJiY2VTgfcB/hd8NeA/w83CV0XjOY4GzgB8AuHvW3VsY5b81wR2XK8wsAVQCjYzC39rdHwe27VHc12+7ALjLA08B48xscn+PFbUw6O1Zy3s+nHVUMbPpwEnA08Akd28MF20CJpWrXsPk34AbgWL4fQLQ4u758Pto/L1nAE3AD8Pusf8ysypG8W/t7huBfwZeJwiBHcByRv9v3aOv33ZQf79FLQwixcyqgfuBv3X31tJlHswpHjXzis3s/cAWd19e7rqMsARwMnC7u58EtLNHl9Ao/K1rCf4VPAM4Aqhi766USBjK3zZqYRCZZy2bWZIgCH7i7g+ExZt7mo3h+5Zy1W8YnAF8wMzWE3T/vYegL31c2JUAo/P3bgAa3P3p8PvPCcJhNP/Wfw285u5N7p4DHiD4/Uf7b92jr992UH+/RS0MIvGs5bCv/AfAS+5+W8miJcDV4eergV+MdN2Gi7t/2d2nuvt0gt/1t+7+YeB3wN+Eq42qcwZw903AG2Z2bFh0HvAio/i3JugeOt3MKsP/1nvOeVT/1iX6+m2XAFeFs4pOB3aUdCftn7tH6gVcBKwGXgW+Uu76DNM5/hVB0/F5YEX4uoigD/0xYA3wG2B8ues6TOd/DvDL8PPRwF+AtcDPgHS56zcM5zsHqA9/74eA2tH+WwNfB14GVgE/BtKj8bcG7iYYF8kRtAKv7eu3BYxgtuSrwEqC2Vb9PpZuRyEiIpHrJhIRkV4oDERERGEgIiIKAxERQWEgIiIoDEREBIWBiIgA/x9VhAPqUIBryQAAAABJRU5ErkJggg==\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["pandas.DataFrame(dict(skl_loss=nn.loss_curve_, ort_loss=train_session.train_losses_)).plot();"]}, {"cell_type": "code", "execution_count": 42, "id": "018f9ae3", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["C:\\Python395_x64\\lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:692: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (100) reached and the optimization hasn't converged yet.\n", " warnings.warn(\n"]}, {"name": "stdout", "output_type": "stream", "text": ["1.98 s \u00b1 0 ns per loop (mean \u00b1 std. dev. of 1 run, 1 loop each)\n"]}], "source": ["%timeit -n 1 -r 1 nn.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 43, "id": "1a956825", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["1.88 s \u00b1 0 ns per loop (mean \u00b1 std. dev. of 1 run, 1 loop each)\n"]}], "source": ["%timeit -n 1 -r 1 train_session.fit(X_train, y_train)"]}, {"cell_type": "markdown", "id": "bb7d75d9", "metadata": {}, "source": ["## Not exactly the same: Nesterov?\n", "\n", "The loss curves of scikit-learn and onnxruntime-training are close but not exactly the same. One candidate explanation is the momentum update, so the next cells train the same model again with Nesterov momentum enabled (`nesterov=True`, `momentum=0.9`) and plot the three curves together."]}, {"cell_type": "code", "execution_count": 44, "id": "534dfc61", "metadata": {}, "outputs": [], "source": ["from onnxcustom.training.sgd_learning_rate import LearningRateSGDNesterov\n", "\n", "train_session2 = OrtGradientForwardBackwardOptimizer(\n", " onx, device='cpu', warm_start=False, max_iter=max_iter, batch_size=batch_size,\n", " learning_rate=LearningRateSGDNesterov(1e-5, nesterov=True, momentum=0.9))"]}, {"cell_type": "code", "execution_count": 45, "id": "6e209570", "metadata": {}, "outputs": [{"data": {"text/plain": ["OrtGradientForwardBackwardOptimizer(model_onnx='ir_version...', weights_to_train=\"['I0_coeff...\", loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=15, learning_rate=LearningRateSGDNesterov(eta0=1e-05, alpha=0.0001, power_t=0.25, 
learning_rate='invscaling', momentum=0.9, nesterov=True), value=3.162277660168379e-06, device='cpu', warm_start=False, verbose=0, validation_every=10, learning_loss=SquareLearningLoss(), enable_logging=False, weight_name=None, learning_penalty=NoLearningPenalty(), exc=True)"]}, "execution_count": 46, "metadata": {}, "output_type": "execute_result"}], "source": ["train_session2.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 46, "id": "788fee10", "metadata": {}, "outputs": [{"data": {"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAu10lEQVR4nO3deXwc5Z3n8c+vqqtvSZZkWdiW8QE25jLGGGOSMUcI5kgyHMlgCBMD4cjmACY7G8jOMK9sSJiZMFlmkt1M2GxCgB1yQEISh3CEIcwQkuAgE+MTbGMMlrGxLetWn1XP/lElq20kW9bVsur3fr36pe6nq6ueUtn91fM8VfWIMQallFLhZpW7AkoppcpPw0AppZSGgVJKKQ0DpZRSaBgopZQCIuWuwGBNnDjRzJgxo9zVUEqpo8qqVav2GmPqDi4/asNgxowZNDY2lrsaSil1VBGRt/oq124ipZRSGgZKKaU0DJRSSnEUjxkopca3QqFAU1MT2Wy23FU5KsXjcRoaGnAcZ0DLaxgopcakpqYmKioqmDFjBiJS7uocVYwxNDc309TUxMyZMwf0Ge0mUkqNSdlsltraWg2CQRARamtrj6hVpWGglBqzNAgG70h/d6ELgx9s/AFPv/l0uauhlFJjSujC4Cebf8LT2zQMlFKqVOjCIO2k6Sx0lrsaSqmj0IwZM9i7d+97ytPpdL+f2bZtG6eccspIVmtYhC4MUk6KrnxXuauhlFJjSuhOLU07aZo6mspdDaXUEfjyL9ez4Z32YV3nSVMq+dJHTu73/a6uLq666iqamppwXZe/+7u/2/9eJpPhyiuv5Morr+Tmm28e8Daz2Syf/vSnaWxsJBKJcN9993H++eezfv16brjhBvL5PJ7n8dOf/pQpU6a8Z/vLli0b0j4fSujCIOWk6Cpoy0ApdWhPP/00U6ZM4Ve/+hUAbW1t3HnnnXR2dnL11VezfPlyli9ffkTr/Na3voWIsHbtWl577TWWLl3Kpk2buP/++7n99tu59tpryefzuK7Lk08++Z7tj6TQhYGOGSh19DnUX/Aj5dRTT+Wv//qvufPOO/nwhz/MkiVLALjsssu44447uPbaa494nS+++CK33norAHPnzmX69Ols2rSJs88+m3vuuYempiauvPJKZs+e3e/2R0r4xgyiKTLFDK7nlrsqSqkxbM6cObzyyiuceuqp3HXXXdx9990AvP/97+fpp5/GGDNs2/r4xz/OihUrSCQSXHrppfzmN7/pd/sjJXRhkHb8Uf+uonYVKaX6984775BMJvnLv/xLvvCFL/DKK68AcPfdd1NdXc1nP/vZI17nkiVLeOSRRwDYtGkTb7/9NieccAJbt25l1qxZ3HbbbVx22WWsWbOm3+2PlMOGgYhME5HnRWSDiKwXkduD8hoReVZENgc/q4NyEZFvisgWEVkjIgtK1nVdsPxmEbmupPwMEVkbfOabMoKXHe4PAz2jSCl
1CGvXrmXRokXMnz+fL3/5y9x111373/vGN75BJpPhjjvuOKJ1fuYzn8HzPE499VSWLVvGgw8+SCwW49FHH+WUU05h/vz5rFu3juXLlx9y+yPCGHPIBzAZWBA8rwA2AScB9wJfDMq/CHwteH4p8BQgwGJgZVBeA2wNflYHz6uD9/4YLCvBZy85XL3OOOMMMxjPvPmMOeXBU8ymfZsG9Xml1OjYsGFDuatw1Ovrdwg0mj6+Uw/bMjDG7DTGvBI87wA2AlOBy4CHgsUeAi4Pnl8GPBxs9yVggohMBi4CnjXG7DPGtADPAhcH71UaY14KKvpwybqG3f6WgZ5RpJRS+x3R2UQiMgM4HVgJ1BtjdgZv7QLqg+dTge0lH2sKyg5V3tRHeV/bvwW4BeDYY489kqrvl4qmAPSMIqXUsFq7di2f+MQnDiiLxWKsXLmyTDU6MgMOAxFJAz8F/soY017arW+MMSIyfEPr/TDGfAf4DsDChQsHtb2eloGGgVJqOJ166qmsXr263NUYtAGdTSQiDn4QPGKMeTwofjfo4iH4uTso3wFMK/l4Q1B2qPKGPspHRMrxWwY6gKyUUr0GcjaRAN8DNhpj7it5awXQc0bQdcAvSsqXB2cVLQbagu6kZ4ClIlIdnHm0FHgmeK9dRBYH21pesq5hpy0DpZR6r4F0E70f+ASwVkRWB2V/A/wj8KiI3Ai8BVwVvPck/hlFW4Bu4AYAY8w+EfkK8HKw3N3GmH3B888ADwIJ/LOJnhr8Lh1a0kkCOoCslFKlDhsGxpgX8U/57MsFfSxvgD6vxjDGPAA80Ed5IzAq93i1xCIZSWrLQCmlSoTuCmSeu5s0oi0DpdSw+Pu///vDLnOo+Q7GivCFwaZnSBWLdOa1ZaCUGjxjDJ7nDSgMjgahu2sp0TRpk9WWgVJHk6e+CLvWDu86jzkVLvnHQy5y33338cADfs/2TTfdxOWXX85FF13EWWedxapVq1i0aBGZTIb58+dz8skn77/vUH+MMdxxxx089dRTiAh33XUXy5YtY+fOnSxbtoz29naKxSLf/va3ed/73seNN95IY2MjIsInP/lJPv/5zw/b7h8sfGEQS5Mq7NExA6XUIa1atYrvf//7rFy5EmMMZ511Fueeey6bN2/moYceYvHixQA89thjA76+4PHHH2f16tW8+uqr7N27lzPPPJNzzjmHH/zgB1x00UX87d/+La7r0t3dzerVq9mxYwfr1q0DoLW1dYT21BfCMKggnXXZqy0DpY4eh/kLfiS8+OKLXHHFFaRS/rVJV155Jb/97W+ZPn36/iAYzDqvueYabNumvr6ec889l5dffpkzzzyTT37ykxQKBS6//HLmz5/PrFmz2Lp1K7feeisf+tCHWLp06XDu3nuEb8wgmiblFrRloJQalJ5wGE7nnHMOL7zwAlOnTuX666/n4Ycfprq6mldffZXzzjuP+++/n5tuumnYt1sqfGEQqyBdLOgVyEqpQ1qyZAk///nP6e7upquri5/97Gd9zjbmOA6FQmHA6/zxj3+M67rs2bOHF154gUWLFvHWW29RX1/PzTffzE033cQrr7zC3r178TyPj370o3z1q18d8fkMQtlNlCrk6Cx0YoxhBKdOUEodxRYsWMD111/PokWLAH8Aubq6+j3L3XLLLcybN48FCxYcdgD5iiuu4A9/+AOnnXYaIsK9997LMcccw0MPPcQ//dM/4TgO6XSahx9+mB07dnDDDTfgeR4A//AP/zD8O1lCzDBO3TaaFi5caBobG4/8g7/7Jt//473cV1PNyo+v3H9FslJqbNm4cSMnnnhiuatxVOvrdygiq4wxCw9eNnTdRBuaPVKeH4A6bqCUUr7QdRP9fGM78y2/2dVZ6GQSk8pcI6XUeNDc3MwFF7znDj0899xz1NbWlqFGRyZ0YeA6adIFPwx0EFkpNVxqa2vH/3wG44nnpLSbSCmlDhK6MDDRCtImaBnohWdKKQWEMAyIVZDyescMlFJ
KhTAMJJYmHXQTactAKaV8oQsDO17SMtDbWCulhmi05zP4whe+wNy5c5k3bx5XXHHFsN3ALnRhEImlEGMRw9KWgVJq0Mo1n8GFF17IunXrWLNmDXPmzBm2K5NDd2ppMhahizgpieiYgVJHia/98Wu8tu+1YV3n3Jq53LnozkMuMxbnMyi9e+nixYv5yU9+MvRfBmEMg6hNJwlSWBoGSql+HQ3zGTzwwAMsW7ZsWPY3dGGQiNp0mgQpowPISh0tDvcX/EgY6/MZ3HPPPUQiEa699toh7yuEcMwgGbXpIk7SMzqArJQ6YmNhPoMHH3yQJ554gkceeWTY7rwcujBIOBG/ZeB52jJQSvVrrM5n8PTTT3PvvfeyYsUKksnhu+ty6LqJesYM0m4HW3XMQCnVj7E6n8HnPvc5crkcF154IeAPIt9///1D3t/QzWfQuG0fb373OjZMfYN/r67lt1f/dgRqp5QaKp3PYOh0PoNDSAQtg4pifv9sZ0opFXYh7CaK7A+Dolck7+WJ2bFyV0spdZTT+QyOMsmoTZeJM9FzAf+WFLGEhoFSY9HRNE/5WJvP4Eh7PULbTaS3sVZqbIvH4zQ3N2tX7iAYY2hubiYejw/4M+FrGTj+RWdJneBGqTGtoaGBpqYm9uzZU+6qHJXi8TgNDQ0DXj50YRCxLXJWkrSnLQOlxjLHcZg5c2a5qxEaoesmAihGkvvnNNCrkJVSKqxh4KR1tjOllCoRyjDwoikdQFZKqRKhDAMTTZPSAWSllNovnGHgVBI3BhvRloFSShHSMLBiKQxCynI0DJRSigGEgYg8ICK7RWRdSdn/EJEdIrI6eFxa8t5/F5EtIvK6iFxUUn5xULZFRL5YUj5TRFYG5T8Wkehw7mBfkrEIGeKksTUMlFKKgbUMHgQu7qP8n40x84PHkwAichJwNXBy8Jl/FRFbRGzgW8AlwEnANcGyAF8L1nU80ALcOJQdGohE1KaLpD/1pZ5aqpRShw8DY8wLwL4Bru8y4EfGmJwx5k1gC7AoeGwxxmw1xuSBHwGXiX/TkQ8APTM6PwRcfmS7cOT8OQ3ipNEBZKWUgqGNGXxORNYE3Ug9Mz5MBbaXLNMUlPVXXgu0GmOKB5WPqGQ0QqeJU+lBe759pDenlFJj3mDD4NvAccB8YCfwP4erQociIreISKOINA7lfiUJx6bdi1PhebTnNAyUUmpQYWCMedcY4xpjPOD/4ncDAewAppUs2hCU9VfeDEwQkchB5f1t9zvGmIXGmIV1dXWDqToQ3MaaBBXFgrYMlFKKQYaBiEwueXkF0HOm0QrgahGJichMYDbwR+BlYHZw5lAUf5B5hfHvTfs88LHg89cBvxhMnY5Ez5hBZbFAZ6GTolc8/IeUUmocO+xdS0Xkh8B5wEQRaQK+BJwnIvMBA2wDPgVgjFkvIo8CG4Ai8FljjBus53PAM4ANPGCMWR9s4k7gRyLyVeBPwPeGa+f6k4hG6DQJKgs5IE5HvoPq+HsnulZKqbA4bBgYY67po7jfL2xjzD3APX2UPwk82Uf5Vnq7mUZFMmqzmwQNhQwQpz3frmGglAq1UF6BnIjadJo4VcUCgA4iK6VCL5RhkHT8qS8rg9tY6yCyUirswhkG0QhdRsNAKaV6hDIMEj1nEwVh0JZrK3ONlFKqvEIZBv6ppdoyUEqpHqENgy6TIGYgbjk6gKyUCr1QhkFPNxFApR3XloFSKvRCGQZR2yIrSQAqraiGgVIq9EIZBiKC66QBqBRHw0ApFXqhDAMA4wQtA7H1bCKlVOiFNgzisSg5iVNpRFsGSqnQC20YJKIRMlaSSqO3o1BKqdCGQTJqkxH/WoPuYjcFr1DuKimlVNmEOgy6SVDp+nMZdOQ7ylwjpZQqn9CGQcKx6SBJZSEP6C0plFLhFtowSEZt2kySqnwG0FtSKKXCLbRhkIhGaPFSVGY7AR1EVkqFW2jDIBm12eclqcz6IaAtA6VUmIU6DJrdBJXaTaSUUuENg0T
Ups2kqOq5jbV2EymlQiy0YZB0/DBwgIQdoy2vZxMppcIrvGEQjdBGCoCKSEJbBkqpUAttGPR0EwFU6ZwGSqmQC20YJKP2/pZBpeicBkqpcAttGJS2DCrF1jBQSoVaaMMgGY3Q3tMyMKJjBkqpUIuUuwLlkozauNgUIikqPUN7QcNAKRVeoW0ZJBwbgLxTSaXnkilmKLh6G2ulVDiFNgySUT8McpFKqnruXKrXGiilQirEYeD3kGXsCiqLOUBvSaGUCq/QhkHcsRCBbitNZa4b0FtSKKXCK7RhICIkHJsuK01lrgvQloFSKrxCGwbgjxt0SJrKjB8COtuZUiqsQh0GiahNO2kq80E3kbYMlFIhFeowSDoR2kySyp7bWGsYKKVCKtRhkIjatHj+bayTdlwHkJVSoRXuMHBsWk0SgMpIUlsGSqnQCncYRG2a3SAM7Ji2DJRSoXXYMBCRB0Rkt4isKymrEZFnRWRz8LM6KBcR+aaIbBGRNSKyoOQz1wXLbxaR60rKzxCRtcFnvikiMtw72Z+4Y/WGgTjaMlBKhdZAWgYPAhcfVPZF4DljzGzgueA1wCXA7OBxC/Bt8MMD+BJwFrAI+FJPgATL3FzyuYO3NWLiEZu9xQQAlRLRMFBKhdZhw8AY8wKw76Diy4CHgucPAZeXlD9sfC8BE0RkMnAR8KwxZp8xpgV4Frg4eK/SGPOSMcYAD5esa8TFozZ7i3EA0ga6Cl2jtWmllBpTBjtmUG+M2Rk83wXUB8+nAttLlmsKyg5V3tRHeZ9E5BYRaRSRxj179gyy6r3iEZvOgkA0Tdrz6Mx3DnmdSil1NBryAHLwF70ZhroMZFvfMcYsNMYsrKurG/L6ElGLbNGD+ARSrktnoRN/d5RSKlwGGwbvBl08BD93B+U7gGklyzUEZYcqb+ijfFTEIzauZzDxKircPAZDppgZrc0rpdSYMdgwWAH0nBF0HfCLkvLlwVlFi4G2oDvpGWCpiFQHA8dLgWeC99pFZHFwFtHyknWNuEQwp4EbqyJV8G9j3VnQriKlVPgcdtpLEfkhcB4wUUSa8M8K+kfgURG5EXgLuCpY/EngUmAL0A3cAGCM2SciXwFeDpa72xjTMyj9GfwzlhLAU8FjVMSc3jBId+2BOHTmO5mUnDRaVVBKqTHhsGFgjLmmn7cu6GNZA3y2n/U8ADzQR3kjcMrh6jESeqa+LDhVpHPdELe1ZaCUCqVQX4Ecd/zdz0crSec6AO0mUkqFU6jDoKdlkLMrSOX9gWO91kApFUahDoN4TxhEKkgb/zbWeq2BUiqMNAyAbruCtOdfX6DdREqpMAp5GPi7321XkgwmuNEwUEqFUajDoGfMoEvSOEDCitKV1zEDpVT4hDoMerqJOiUNQMqKastAKRVKoQ6DnpZBm6QASFsRDQOlVCiFOgx6WgYdxg+DFHrRmVIqnEIdBrGIv/tdruXfxhrRMQOlVCiFOgwsS4hFLHIFF+ITSHtGWwZKqVAKdRiAf+fSTMGFxARSnqthoJQKpdCHQTxik+1pGbhF7SZSSoWShoFjkSl4kKwmXcjqbGdKqVDSMHCClkHFFNKZDp3tTCkVShoGPWFQNZVUwQ+BjnxHmWullFKjK/RhkOgJg8qppIP7E+ltrJVSYRP6MIg7FtmCB1UN+8NAzyhSSoVN6MNg/6mllVP1NtZKqdAKfRjsP7W0YjKp4CQineBGKRU2GgbRIAzsCOnkREDHDJRS4aNhELH9MQMgnZ4MaDeRUip8Qh8GiajljxkAqcoGQMNAKRU+oQ+DeMTG9QwF1yNS1UDCMzpmoJQKndCHQSLqz2mQKbhQ1UDK8+jKNJe5VkopNbpCHwaxYIKb0gvPOrv3lrlWSik1ukIfBj1TX2bzHlRNJW08OrMtZa6VUkqNrtCHQdzxfwXZogtV00h5hi69N5FSKmRCHwY9LYNM3oXkRNJGzyZSSoVP6MMgXjpmYFmk7Tidbq7MtVJKqdGlYeCUnE0EpKNpurx
COauklFKjTsOgZ8wguAo5FauiE6OznSmlQiX0YZAo7SYC0vEajEC3DiIrpUIk9GEQPzgMUnUAdLZuK1eVlFJq1GkYHDxmENysrqvlzbLVSSmlRlvow6C3mygYM6iYCkBn21tlq5NSSo220IdBLNIzgBy0DKqOBaCzfUfZ6qSUUqMt9GFgWUIsYvWGQfoYADq7dpWzWkopNaqGFAYisk1E1orIahFpDMpqRORZEdkc/KwOykVEvikiW0RkjYgsKFnPdcHym0XkuqHt0pGLO3ZvGEQrAOjq2j3a1VBKqbIZjpbB+caY+caYhcHrLwLPGWNmA88FrwEuAWYHj1uAb4MfHsCXgLOARcCXegJktCQcu3eCGycFoDerU0qFykh0E10GPBQ8fwi4vKT8YeN7CZggIpOBi4BnjTH7jDEtwLPAxSNQr37FHat3ALknDLr3QiEzmtVQSqmyGWoYGODXIrJKRG4JyuqNMTuD57uA+uD5VGB7yWebgrL+yt9DRG4RkUYRadyzZ88Qq94rXtIyiFgRElaUTlOAjb8ctm0opdRYNtQw+DNjzAL8LqDPisg5pW8a/54Ow3ZfB2PMd4wxC40xC+vq6oZrtQeMGQCkY1V0xSvhlYeHbRtKKTWWDSkMjDE7gp+7gZ/h9/m/G3T/EPzsGYndAUwr+XhDUNZf+ahJHBQGKSdFZ/V02PZb2Ld1NKuilFJlMegwEJGUiFT0PAeWAuuAFUDPGUHXAb8Inq8AlgdnFS0G2oLupGeApSJSHQwcLw3KRk3pmAFA2knTmZ4IYsGfHhnNqiilVFlEhvDZeuBnItKznh8YY54WkZeBR0XkRuAt4Kpg+SeBS4EtQDdwA4AxZp+IfAV4OVjubmPMviHU64glor1jBgCpaIrOYhaO/yCs/gGc/zdg2aNZJaWUGlWDDgNjzFbgtD7Km4EL+ig3wGf7WdcDwAODrctQxSMHdhNVOBU0Z5rh9Bvg0U/AludgztJyVU8ppUZc6K9ABohHDwqDaAUt2RaYczEkJ8KfdCBZKTW+aRjQ0zLoHTOYUz2H5mwzu/OtcNrV8PpT0KG3p1BKjV8aBkAiah0wZjCvbh4Aa/asgYWfBK8Iqx4sU+2UUmrkaRjgtwxcz1Bw/dbB3Jq5OJbjh0HtcXDcBX4YuDo3slJqfNIwwD+bCHonuInaUU6qPYlX97zqL7DoZujYCa/9qlxVVEqpEaVhAMQOmvoS/K6i9c3rKXgFmL0Uqo6Fl79brioqpdSI0jCgZLazfO8g8ry6eeTcHJtaNvnXGJz5Sf+K5N0by1VNpZQaMRoG+FcgA2SLvS2D+XXzgWAQGeD05WDHtHWglBqXNAzwB5DhwG6i+mQ9kxKTescNUrVwypXw6o8g21aOaiql1IjRMKBkADnfGwYiwry6eb0tA4Cz/gvkO/VupkqpcUfDgNJuIu+A8tPqTmN7x3b2ZYNbJU2ZDzOWwEv362mmSqlxRcMAfz4DOLBlAAddfNbj7M9BexNs+AVKKTVeaBjQGwa54oFhcGLtiUQkcmAYzF4KtbPh9/8LzLDN26OUUmWlYUDvqaUHtwwSkQRzaubwp91/wjNBF5JlwdmfgZ2r4a3fj3JNlVJqZGgY0NsyKD2bqMfpk06n8d1GznrkLD624mPc13gf3rxlkKyFP/zv0a6qUkqNiKFMbjNu7G8ZFLz3vPfp0z7NrKpZbGvfxvq96/n++u+zpGEJZ555E/znvfDG83Dc+aNdZaWUGlYaBkAsEpxN1EfLoCpWxVUn+JO1ZYoZzvvxeTyx9QnOfN8XYOMv4bHr4ZbnoWbWaFZZKaWGlXYTAZYlxCJWn2FQKhFJ8MHpH+TX235N1o7A1cH8yD/8OOQ6RqGmSik1MjQMAnHHPmwYAHzkuI/QWejkP5r+w28N/MWDsPd1+Nl/Ae+93UxKKXU00DAIJBz7gAlu+nNm/ZlMSk7iiTee8AuOOx+W3gOvPQG//foI11IppUaGhkEg7lgHTH3ZH9uy+dC
sD/G7Hb/rvTJ58adh3tXw/N/D60+PcE2VUmr4aRgE4gNsGQB8ZNZHKJoiT735lF8gAh/5F5h8Gjx+M+zdPHIVVUqpEaBhEBjomAHA7OrZzK2Z29tVBOAkYNm/gR2FH14DW56DQnaEaquUUsNLwyCQOIIwAPjz4/6cdc3rePT1R3sLJ0yDqx7yp8j8tyvh3pnwyFXw+/8N7/wJ3OII1FwppYZOrzMIxB2LXe0DvxPp1SdczcqdK/nKS1/BEouPzfmY/8aMP4P/thm2vQibfw1b/h02P+O/F6uED/8znPqxEdgDpZQaPG0ZBOZPq2bjznZ+v2XvgJZ3bIf7zruPJVOX8OU/fJnHNj3We/+iaBLmLIUPfR1uXw3/dSN89Hsw6UT42adg87MjtyNKKTUIYo7SO28uXLjQNDY2Dtv6sgWXi//lBUSEp25fsv9+RYeTc3Pc/vzt/G7H74jZMaZXTue4quP4wLEf4Lxp5xGPxEs20g4Pfgiat8DyFTDtzGGrv1JKDYSIrDLGLHxPuYZBrxc37+Uvv7eS2z5wPP916QkD/lzOzfHUm0+xpWULb7a/ycbmjezJ7CHlpFg6fSm3nn4rdck6f+HO3fC9pZBthTNvhkjUH3SOVUJqIiQnQt0JkKwZ1n1TSinQMBiwz/94NU+seYcnb1vC7PqKQa3D9Vwa323kl2/8kme2PUNtopb/c+H/YXrldH+BfW/Cv30U9r3RzxoEjjkVZp0LJ1wKx57tn76qlFJDpGEwQM2dOS647z+ZVp3ke9cvZFJF/PAfOoT1e9fz6X//NCLCv37wXzm59uTeN40BrwhuHrJt0LUXunbDjldg639C0x/99ybOgQXXwWnXQKp2iHuolAozDYMj8OyGd7n1h69QEXf4X9eczuJZQ/sC3ta2jU89+ylac628b8r7qIpVUR2v5pKZlzCnek7/H8x1+tNrrnrQDwYrAsdfCPOu8s9a6gmSeBUkqodUR6VUOGgYHKHXdrXzmUdeYdveLm67YDY3L5lFKjb4M3F3d+/mqy99lbfb36Y110prrhVBuP6U6/nUvE8dONDcl3c3wKs/hLWP+dcxlLKjcPon4M/+CiYcO+g6KqXGPw2DQejMFfmbx9ey4tV3qEo4LD97Ote9bwYT07Ehr7s128rXG7/OL974BdMqpnHb6bdxwbEX4NjOoT/ouf41DHteB9uBSAy2r4Q/PQIYmHmO3+XU/g64BTjpz2H+tTD1DB13UEppGAzFK2+3cP9/vMGvN7yLYwvvP34il5xyDBeedAw1qeiQ1r1y50q++tJX2da+jZp4DZcffzl1iTo2t25mS+sWqmPVXDrzUs4/9nwSkUT/K2prgt99ww+KdD1UToViBl570v9ZMwsazvQHpqtnQsubsHsjtLwFDQv90JiywF9X1x5o2QZiQazCf6Tq/PBRSh3VNAyGwZbdnfz45bd5at0umloyiMAJ9RUsmlnDwhk1nDylkhm1KWzryP4Cdz2X37/zex7b9BgvNL2Aa1yqY9UcX308b7e/zbvd75KMJDmj/gympKdwTOoYpldOZ97EedSn6gEwxrAvu4+CV+CY1DG9K8+2w4afw8YnYNda6Hin972e0Ni1xh9/SE2CQgbyfUzUIxZUTPG7oepPgmmL4dizoGqatjiUOopoGAwjYwzr32nnN6/t5uVt+1j1Vgvdef++RnHHYk59BTMnpphek+TY2hTTqhNMrU5wTGWciH3oi773ZffhGY+JiYkAeMZj1bur+NXWX7G+eT07u3bSlmvbv3x9sp76ZD3b2rfRnm8HYMGkBfzFCX/BhdMvJGYf1KXV1ez/1V89A1K1GGOQTAtsegbe+I0/EF17nN96EIFce9DttBPatvstiV1rIN/pr8+OBQPYE/wuq55/TrbTWy42ZPZBd7P/fOYSOO4CmDwPOnb562zf4V97kWn1u7caFsKs8/xrL5RSw0bDYAQVXI/Xd3WwcWc7r+3q4PVdHWxr7uKd1gxeya/XtoS
J6Sh1FTHq0jFq0zFqUlGqk1Gqkw4Tkg5ViSgTkg4V8QiVCYd0NIJ1UEuju9DNG61v8OqeV1mzZw37svuYUTWDmVUzyRQzPL75cbZ3bCcRSdBQ0cDU9FTqEnW059tpybbQkmuhPddOR76DvJvn3Gnn8vG5H+fMY85ERGjONLOpZRMAKSdFMpIEoGiKFL0icXGo6dxN1a4N2O07/LDItPpnNgEgUMz6QZJp9VsdyVr/ke+E7X8Er5/7QEXifiuk0O2/rjvRvwDPdiCS8C/Im3qGHyTdzbBnE+zdBJbtB1miunf7xRxUHANTFxzYgnGDbR/c7dXzf0FbOmoc0zAog3zRo6mlmx2tGXa0ZNjRmmF3e47dHVl2d+Ro6crT3JUnVzz0pDqpqE06HiEVjZCI2vt/JqM2CccmEbWJOzZxxyIesYlFhHeL69nW/TJthV20FnbRWWwlFamgMlpNZbSKymgVFdE0BpcX3nmW9nwbDeljyblZ9mR2D2j/LLGYnJrMiTUnMrdmLg0VDSQjSZJOkmQk6QeJkyTtpEk5KaTnSzbX0TsIXtUAE6ZD1VRI1IAT9wfJ31kNW5/3g6PQ7QdNrtP/4j84SKwIGM9/9PtLrIP4BH88JNvqlyUn+mFhR3uv8QD/HlL1J/utJ7fgd50Zz2/lJGshmoZMS7CuNn89lZODbrZuf/25Dv+q8nS9v20Rfz3FHMQr/e2m6/26F7r996yI/xlb7x+pRs6YDwMRuRj4BmAD3zXG/OOhlj8awmCguvNFWrsLwSNPe7ZAe6ZIW6ZAR65IV65IZ7ZIV75IJu/SlS/SnXfJ5F268y7ZQvAoerjeII6nFIhUrsapfBXjpnGzU/Gyk7Ekgm3niEQKWJZgi40tDpadx7K7wO7Ei+ymYG+naO85zEYsIiRxSGJLDFtiOBLHkSRRSRO1kkTEwbYiWGLjkaVosngUqIjUUe1MZYIzmQQwIfMWE7q3sdeJsCuaYI/lUTB5xGQQN0NNpIapyeOZmpzNxHwbFfteJdm2lpzbRWu0gtZoEhGYmO9kUraNKq+IF6+jkKgFz6OjfQO7ut4kU+xkVr7ADM/CQbDcA+enMGLhOWnsoHtuuJhoGpwU2BGM5SDGg0IX5Lv8BZITkdREvxXkJPzuOcvxA7OY8396RT9UMb3XocQq/c97Rb88UQPpSX4oieV/zs37t0xp3+E/nKTfqqpqgGgqWG/RDy4n6QejZfnbLWb9elRO8ceikrVBK0v87Xlu7+e9oh+wxVzQsmzx969yst9FGUsf4hdkgs/s87spk7X+HxH9Keb89Xfv8+tdPcO/DcwRHxgzLlqN/YXBmPgTRERs4FvAhUAT8LKIrDDGbChvzUZHMhohGY0wZcIhzhYaoILrkSt6+wMiX/Rf54oe+aJHwfX2l/U8L3oeefd0/7nrUfTM/vKiayi4xn/uGf991/jPg/ddz5B1u8l6LRSN/yWe9zK4ZHFNliLduHTjSgaXbgqSx0gOjw6MtRusDEgGIy4ifpgZY4EXwxgbK9LZz84CQW+S8SJgLMBC7Cy0sH89Ip5/f14Lf0wjV7KOaM9y70BnBMRFYkWIxYH4/nWYQjViwBIXwcPzEhTdFMZNYBmbuHGJUyRvFyjYWdxIBgtDxEDEgGVsxItivChxT0h5HhO8IgLkRMiJYGGIG5ckRRKeR9wzJFwwRui06uhM2rjikTZ50l0ZUl1t2Lg4uDh4xD2bhGsT8Sw6LZtOW8haBoftxMgTJ0/UQMxA3IMqMsSNwTGGIkJehJz4vybPROmigrgpUEM7ETw8oIhQFP/r3TYQwf981Bhixv/1tlsWHZaFK1DjutS6HmnPIydC1hIyInRZFt3BfseNIWkMMc+QtYQusWi2kyQ8mOAVqHYLxAzY2ESMBWTIWNBp+Z+d4HrEjUO3xGm2HVpsGxuPWjfHxGIGmyK7bZs9EZu8CGkXXJmIJxOoMEUqvRyO8fevIDauGGxTxDE5IiZPxBSImByCR4ddTVu
kji67lgiCY1xsUyRvxclaKbJ2Est4OCaPY/JETBGbIrYpUpQoWStJzk5hEMRk8EwGY4pYeFi42ESwJU3RTpK3YhTEotUq4hqPerdApZsl7nVzwq0/xYkM79f3mAgDYBGwxRizFUBEfgRcBoQiDIaTY1s4tkV6CBfIlZPrubjGxbEcRATPM3QVunm7fTtvd2yn4BbxMBjPUBWroS5RT02sDlscPOMHU2uulc2tG9nc+hoZN0PUiuJYMaJWglSkgmSkAtcY2vMttOdb6Sp0UvDy5NwsFhEmJRqYFG8gZqXYmXmLHd1b2Zv1L/SzsDF4ZIrdZNwOMm4HRa9A0RTwjEulXUE6UkfSnoAguMZ/r+DlyXuZ4NHNu143270MBuO3iMSBYHnXFDD03eXlR8bB79nBA/yvYzd4lIqyP/n2O8Rf3/vFgLoBLDeaJgxwuThQdYj3e0/EiBj/N2ak9HcbCx4H6woeEPUgasBgcAU8gqAErCAcjUhQbpDgZ16EQj+tDDGGCs9gYWi1D7x7coVrqHHhwc49TJww+ZB7f6TGyjfGVGB7yesm4KyDFxKRW4BbAI49Vq+0HY9sy8am9z+AZQkVsRQn183l5Lq5A1rHNJKcOnkKcMEw1OjsYVhH33q6aKWPL4WCW6C72E2mmMEzHiknRcpJEbEieMaj6PmD+SKCIBS9Im35NtpybXQXuklH01RGK0lH09hiIwgGQ7aYJVPMkClmyHt5Cm6BglcgYkWI2lGiVhTPeBS8Ank3j2v8UHGNG3QT2kSsyP6yolek4BbIuTlynt/kqoxWUhWtwhKLfdl9NGeb6ch3kIgkiNtxYpGYvz+RFLFIjFwxR3exm5ybI27HSTkp4pE43YVuOvIdtOfbybpZ8m6evJsnFolREa0g7aTJFXO05FpoybbgWA7V8Wqq49UYY2jLt9GabcW2bOoSdUxKTsKxHdpz7bTn/Ud3oZuuQhc5N4djOTiWQ8SK7P+9CoKIYIl/FqBnPIwxFE2RvJvf//u0xCJiRbAtG4x/soXrub3rCY5xzzF3bIdUxD+mUbs3pLPF7P7jaIyhLllHXXwiUTvKu5nd7Oraxd7MXmqrSk4fHyZjJQwGxBjzHeA74I8ZlLk6Sg1JXyHQw7EdquwqqmLv/cvWEsv/4rYP/Es/HU0zNT31kNtMOanBVVaNe2NlprMdwLSS1w1BmVJKqVEwVsLgZWC2iMwUkShwNbCizHVSSqnQGBPdRMaYooh8DngGf+zlAWPM+jJXSymlQmNMhAGAMeZJ4Mly10MppcJorHQTKaWUKiMNA6WUUhoGSimlNAyUUkoxhm5Ud6REZA/w1iA/PhHYO4zVORqEcZ8hnPsdxn2GcO73YPZ5ujHmPfcYOWrDYChEpLGvu/aNZ2HcZwjnfodxnyGc+z2c+6zdREoppTQMlFJKhTcMvlPuCpRBGPcZwrnfYdxnCOd+D9s+h3LMQCml1IHC2jJQSilVQsNAKaVUuMJARC4WkddFZIuIfLHc9RkpIjJNRJ4XkQ0isl5Ebg/Ka0TkWRHZHPysLnddh5uI2CLyJxF5Ing9U0RWBsf8x8Et0scVEZkgIj8RkddEZKOInD3ej7WIfD74t71ORH4oIvHxeKxF5AER2S0i60rK+jy24vtmsP9rRGTBkWwrNGEgIjbwLeAS4CTgGhE5qby1GjFF4K+NMScBi4HPBvv6ReA5Y8xs4Lng9XhzO7Cx5PXXgH82xhwPtAA3lqVWI+sbwNPGmLnAafj7P26PtYhMBW4DFhpjTsG/7f3VjM9j/SBw8UFl/R3bS4DZweMW4NtHsqHQhAGwCNhijNlqjMkDPwIuK3OdRoQxZqcx5pXgeQf+l8NU/P19KFjsIeDyslRwhIhIA/Ah4LvBawE+APwkWGQ87nMVcA7wPQBjTN4Y08o4P9b4t99PiEgESAI7GYfH2hjzArDvoOL+ju1lwMPG9xIwQUQmD3RbYQqDqcD2ktdNQdm4JiIzgNOBlUC9MWZn8NY
uoL5c9Roh/wLcAXjB61qg1RhTDF6Px2M+E9gDfD/oHvuuiKQYx8faGLMD+DrwNn4ItAGrGP/Hukd/x3ZI33FhCoPQEZE08FPgr4wx7aXvGf+c4nFzXrGIfBjYbYxZVe66jLIIsAD4tjHmdKCLg7qExuGxrsb/K3gmMAVI8d6ulFAYzmMbpjDYAUwred0QlI1LIuLgB8EjxpjHg+J3e5qNwc/d5arfCHg/8Ocisg2/C/AD+H3pE4KuBBifx7wJaDLGrAxe/wQ/HMbzsf4g8KYxZo8xpgA8jn/8x/ux7tHfsR3Sd1yYwuBlYHZwxkEUf8BpRZnrNCKCvvLvARuNMfeVvLUCuC54fh3wi9Gu20gxxvx3Y0yDMWYG/rH9jTHmWuB54GPBYuNqnwGMMbuA7SJyQlB0AbCBcXys8buHFotIMvi33rPP4/pYl+jv2K4AlgdnFS0G2kq6kw7PGBOaB3ApsAl4A/jbctdnBPfzz/CbjmuA1cHjUvw+9OeAzcC/AzXlrusI7f95wBPB81nAH4EtwGNArNz1G4H9nQ80Bsf750D1eD/WwJeB14B1wP8DYuPxWAM/xB8XKeC3Am/s79gCgn/G5BvAWvyzrQa8Lb0dhVJKqVB1EymllOqHhoFSSikNA6WUUhoGSiml0DBQSimFhoFSSik0DJRSSgH/H/LYfhiehs7XAAAAAElFTkSuQmCC\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["pandas.DataFrame(dict(skl_loss=nn.loss_curve_, \n", " ort_loss=train_session.train_losses_,\n", " ort_loss2=train_session2.train_losses_)).plot();"]}, {"cell_type": "code", "execution_count": 47, "id": "331120b3", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["2.26 s \u00b1 0 ns per loop (mean \u00b1 std. dev. of 1 run, 1 loop each)\n"]}], "source": ["%timeit -n 1 -r 1 train_session2.fit(X_train, y_train)"]}, {"cell_type": "markdown", "id": "0bdaba6b", "metadata": {}, "source": ["## Profiling"]}, {"cell_type": "code", "execution_count": 48, "id": "1505d30b", "metadata": {"scrolled": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": [" -- 1 1 -- 0.00001 3.78074 -- :18: ()\n", " fit -- 1 1 -- 0.00181 3.78073 -- onnxcustom/onnxcustom/training/optimizers_partial.py:263:fit (fit)\n", " __init__ -- 1 1 -- 0.00002 0.00003 -- onnxcustom/onnxcustom/training/data_loader.py:26:__init__ (__init__)\n", " get_ort_device -- 1 1 -- 0.00000 0.00000 -- onnxruntime_helper.py:55:get_ort_device (get_ort_device)\n", " numpy_to_ort_value -- 2 2 -- 0.00000 0.00001 -- onnxruntime_helper.py:120:numpy_to_ort_value (numpy_to_ort_value) +++\n", " needs_grad -- 3 3 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/optimizers_partial.py:99:needs_grad (needs_grad)\n", " needs_grad -- 3 3 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:299:needs_grad (needs_grad)\n", " get_full_state -- 101 101 -- 0.00020 0.00093 -- onnxcustom/onnxcustom/training/optimizers_partial.py:147:get_full_state (get_full_state) +++\n", " set_state -- 4 4 -- 0.00008 0.00026 -- onnxcustom/onnxcustom/training/optimizers_partial.py:196:set_state (set_state)\n", " _get_att_state -- 4 4 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/optimizers_partial.py:139:_get_att_state (_get_att_state) +++\n", " numpy_to_ort_value -- 24 24 -- 0.00002 0.00011 
-- onnxruntime_helper.py:120:numpy_to_ort_value (numpy_to_ort_value) +++\n", " -- 12 12 -- 0.00002 0.00002 -- ~:0: ()\n", " -- 56 56 -- 0.00001 0.00001 -- ~:0: () +++\n", " -- 24 24 -- 0.00000 0.00000 -- ~:0: () +++\n", " -- 1 1 -- 0.00001 0.00095 -- onnxcustom/onnxcustom/training/optimizers_partial.py:311: ()\n", " get_initializer -- 7 7 -- 0.00004 0.00094 -- onnxcustom/onnxcustom/training/ortgradient.py:269:get_initializer (get_initializer) +++\n", " -- 1 1 -- 0.00001 0.00083 -- onnxcustom/onnxcustom/training/optimizers_partial.py:315: ()\n", " get_initializer -- 7 7 -- 0.00004 0.00082 -- onnxcustom/onnxcustom/training/ortgradient.py:269:get_initializer (get_initializer) +++\n", " _iteration -- 100 100 -- 0.41903 3.74610 -- onnxcustom/onnxcustom/training/optimizers_partial.py:397:_iteration (_iteration)\n", " iter_ortvalue -- 6800 6800 -- 0.02838 0.14761 -- onnxcustom/onnxcustom/training/data_loader.py:139:iter_ortvalue (iter_ortvalue)\n", " _next_iter -- 6700 6700 -- 0.00946 0.07207 -- onnxcustom/onnxcustom/training/data_loader.py:93:_next_iter (_next_iter)\n", " -- 6700 6700 -- 0.00245 0.00423 -- ~:0: () +++\n", " -- 6700 6700 -- 0.05838 0.05838 -- ~:0: ()\n", " numpy_to_ort_value -- 13400 13400 -- 0.00658 0.03860 -- onnxruntime_helper.py:120:numpy_to_ort_value (numpy_to_ort_value) +++\n", " -- 6900 6900 -- 0.00467 0.00855 -- ~:0: () +++\n", " forward -- 6700 6700 -- 0.31685 0.44643 -- onnxcustom/onnxcustom/training/ortgradient.py:623:forward (forward)\n", " input_to_ort -- 6700 6700 -- 0.08002 0.11492 -- onnxcustom/onnxcustom/training/ortgradient.py:552:input_to_ort (input_to_ort) +++\n", " save_for_backward -- 6700 6700 -- 0.01032 0.01032 -- onnxcustom/onnxcustom/training/ortgradient.py:604:save_for_backward (save_for_backward)\n", " -- 6700 6700 -- 0.00434 0.00434 -- ~:0: () +++\n", " backward -- 6700 6700 -- 0.43012 0.48957 -- onnxcustom/onnxcustom/training/ortgradient.py:702:backward (backward)\n", " input_to_ort -- 6700 6700 -- 0.04148 0.05262 -- 
onnxcustom/onnxcustom/training/ortgradient.py:552:input_to_ort (input_to_ort) +++\n", " saved_tensors -- 6700 6700 -- 0.00207 0.00207 -- onnxcustom/onnxcustom/training/ortgradient.py:613:saved_tensors (saved_tensors)\n", " -- 6700 6700 -- 0.00476 0.00476 -- ~:0: ()\n", " loss_gradient -- 6700 6700 -- 0.05841 0.26967 -- onnxcustom/onnxcustom/training/sgd_learning_loss.py:53:loss_gradient (loss_gradient)\n", " clear_binding_inputs -- 6700 6700 -- 0.00545 0.01270 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:130:clear_binding_inputs (clear_binding_inputs)\n", " _cache_in_clear -- 6700 6700 -- 0.00568 0.00725 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:119:_cache_in_clear (_cache_in_clear)\n", " -- 6700 6700 -- 0.00157 0.00157 -- ~:0: () +++\n", " _bind_input_ortvalue -- 13400 13400 -- 0.02070 0.07545 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:159:_bind_input_ortvalue (_bind_input_ortvalue) +++\n", " _call_iobinding -- 6700 6700 -- 0.11997 0.11997 -- onnxcustom/onnxcustom/training/sgd_learning_loss.py:50:_call_iobinding (_call_iobinding)\n", " -- 13400 13400 -- 0.00315 0.00315 -- ~:0: () +++\n", " penalty_loss -- 6700 6700 -- 0.00112 0.00112 -- onnxcustom/onnxcustom/training/sgd_learning_penalty.py:84:penalty_loss (penalty_loss)\n", " update_weights -- 40200 40200 -- 0.00651 0.00651 -- onnxcustom/onnxcustom/training/sgd_learning_penalty.py:95:update_weights (update_weights)\n", " update_weights -- 40200 40200 -- 0.40487 1.94238 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:345:update_weights (update_weights)\n", " _bind_input_ortvalue -- 201000 201000 -- 0.19630 0.51693 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:159:_bind_input_ortvalue (_bind_input_ortvalue) +++\n", " _bind_output_ortvalue -- 80400 80400 -- 0.07458 0.18952 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:202:_bind_output_ortvalue (_bind_output_ortvalue)\n", " _bio_cache -- 80400 80400 -- 0.04417 0.05406 -- 
onnxcustom/onnxcustom/training/_base_onnx_function.py:138:_bio_cache (_bio_cache) +++\n", " _bio_ptr -- 80400 80400 -- 0.05222 0.05222 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:155:_bio_ptr (_bio_ptr) +++\n", " _bio_do_bind_out -- 12 12 -- 0.00003 0.00003 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:198:_bio_do_bind_out (_bio_do_bind_out)\n", " -- 80400 80400 -- 0.00863 0.00863 -- ~:0: () +++\n", " _call_iobinding -- 40200 40200 -- 0.63987 0.63987 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:28:_call_iobinding (_call_iobinding)\n", " value -- 40200 40200 -- 0.00953 0.00953 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:160:value (value) +++\n", " -- 80400 80400 -- 0.16512 0.16512 -- ~:0: () +++\n", " -- 80400 80400 -- 0.01655 0.01655 -- ~:0: () +++\n", " -- 100 100 -- 0.00026 0.00426 -- ~:0: ()\n", " _mean -- 100 100 -- 0.00163 0.00400 -- site-packages/numpy/core/_methods.py:162:_mean (_mean)\n", " _count_reduce_items -- 100 100 -- 0.00097 0.00107 -- site-packages/numpy/core/_methods.py:66:_count_reduce_items (_count_reduce_items)\n", " -- 200 200 -- 0.00010 0.00010 -- ~:0: ()\n", " -- 100 100 -- 0.00004 0.00004 -- ~:0: ()\n", " -- 100 100 -- 0.00109 0.00109 -- ~:0: ()\n", " -- 100 100 -- 0.00006 0.00006 -- ~:0: () +++\n", " -- 100 100 -- 0.00004 0.00004 -- ~:0: () +++\n", " -- 200 200 -- 0.00007 0.00007 -- ~:0: ()\n", " -- 100 100 -- 0.00358 0.00358 -- ~:0: ()\n", " -- 6700 6700 -- 0.00169 0.00169 -- ~:0: () +++\n", " -- 40300 40300 -- 0.01424 0.01424 -- ~:0: () +++\n", " _create_training_session -- 1 1 -- 0.00001 0.02824 -- onnxcustom/onnxcustom/training/optimizers_partial.py:626:_create_training_session (_create_training_session)\n", " __init__ -- 1 1 -- 0.00008 0.02820 -- onnxcustom/onnxcustom/training/ortgradient.py:54:__init__ (__init__)\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:91: ()\n", " -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:94: ()\n", 
" -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:113: ()\n", " _init_next -- 1 1 -- 0.00010 0.02809 -- onnxcustom/onnxcustom/training/ortgradient.py:163:_init_next (_init_next)\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:173: ()\n", " -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:175: ()\n", " -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:178: ()\n", " _create_onnx_graphs -- 1 1 -- 0.00662 0.02797 -- onnxcustom/onnxcustom/training/ortgradient.py:287:_create_onnx_graphs (_create_onnx_graphs)\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:396: ()\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:397: ()\n", " -- 1 1 -- 0.00001 0.00002 -- onnxcustom/onnxcustom/training/ortgradient.py:399: ()\n", " _provider_name_to_device_type -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:260:_provider_name_to_device_type (_provider_name_to_device_type) +++\n", " -- 1 1 -- 0.00002 0.00002 -- onnxcustom/onnxcustom/training/ortgradient.py:404: ()\n", " _provider_name_to_device_type -- 7 7 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:260:_provider_name_to_device_type (_provider_name_to_device_type) +++\n", " -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:410: ()\n", " _provider_name_to_device_type -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:260:_provider_name_to_device_type (_provider_name_to_device_type) +++\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:479: ()\n", " -- 1 1 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:480: ()\n", " get_inputs -- 1 1 -- 0.00000 0.00000 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:111:get_inputs (get_inputs)\n", " get_outputs -- 1 1 -- 0.00000 0.00000 
-- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:115:get_outputs (get_outputs)\n", " __init__ -- 2 2 -- 0.00004 0.02063 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:283:__init__ (__init__)\n", " get -- 2 2 -- 0.00001 0.00004 -- C:/Python395_x64/lib/_collections_abc.py:759:get (get)\n", " __getitem__ -- 2 2 -- 0.00001 0.00003 -- C:/Python395_x64/lib/os.py:674:__getitem__ (__getitem__)\n", " encodekey -- 2 2 -- 0.00001 0.00002 -- C:/Python395_x64/lib/os.py:746:encodekey (encodekey)\n", " check_str -- 2 2 -- 0.00000 0.00000 -- C:/Python395_x64/lib/os.py:740:check_str (check_str)\n", " __init__ -- 2 2 -- 0.00000 0.00000 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:101:__init__ (__init__)\n", " _create_inference_session -- 2 2 -- 0.02045 0.02055 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:346:_create_inference_session (_create_inference_session)\n", " check_and_nor...rovider_args -- 2 2 -- 0.00004 0.00008 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:25:check_and_normalize_provider_args (check_and_normalize_provider_args)\n", " set_provider_options -- 2 2 -- 0.00001 0.00001 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:53:set_provider_options (set_provider_options)\n", " -- 2 2 -- 0.00000 0.00000 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:62: ()\n", " -- 2 2 -- 0.00000 0.00000 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:75: ()\n", " -- 2 2 -- 0.00000 0.00000 -- onnxruntime/build/Windows/Release/Release/onnxruntime/capi/onnxruntime_inference_collection.py:78: ()\n", " load_model -- 2 2 -- 0.00001 0.00049 -- site-packages/onnx/__init__.py:107:load_model 
(load_model)\n", " _load_bytes -- 2 2 -- 0.00002 0.00003 -- site-packages/onnx/__init__.py:30:_load_bytes (_load_bytes)\n", " inner -- 4 4 -- 0.00000 0.00000 -- C:/Python395_x64/lib/typing.py:262:inner (inner) +++\n", " cast -- 4 4 -- 0.00000 0.00000 -- C:/Python395_x64/lib/typing.py:1333:cast (cast) +++\n", " _get_file_path -- 2 2 -- 0.00000 0.00000 -- site-packages/onnx/__init__.py:50:_get_file_path (_get_file_path)\n", " load_model_from_string -- 2 2 -- 0.00001 0.00045 -- site-packages/onnx/__init__.py:147:load_model_from_string (load_model_from_string)\n", " _deserialize -- 2 2 -- 0.00001 0.00044 -- site-packages/onnx/__init__.py:81:_deserialize (_deserialize)\n", " inner -- 2 2 -- 0.00000 0.00000 -- C:/Python395_x64/lib/typing.py:262:inner (inner) +++\n", " cast -- 2 2 -- 0.00000 0.00000 -- C:/Python395_x64/lib/typing.py:1333:cast (cast) +++\n", " -- 2 2 -- 0.00042 0.00042 -- ~:0: ()\n", " -- 16 16 -- 0.00000 0.00000 -- ~:0: () +++\n", " -- 1 1 -- 0.00014 0.00014 -- ~:0: ()\n", " new_instance -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:211:new_instance (new_instance)\n", " __init__ -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/ortgradient.py:501:__init__ (__init__)\n", " device_to_providers -- 1 1 -- 0.00003 0.00003 -- onnxruntime_helper.py:133:device_to_providers (device_to_providers)\n", " value -- 100 100 -- 0.00003 0.00003 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:160:value (value) +++\n", " init_learning_rate -- 1 1 -- 0.00000 0.00001 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:307:init_learning_rate (init_learning_rate)\n", " init_learning_rate -- 1 1 -- 0.00000 0.00000 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:176:init_learning_rate (init_learning_rate)\n", " update_learning_rate -- 100 100 -- 0.00015 0.00098 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:314:update_learning_rate (update_learning_rate)\n", " update_learning_rate -- 100 100 -- 0.00084 0.00084 -- 
onnxcustom/onnxcustom/training/sgd_learning_rate.py:194:update_learning_rate (update_learning_rate)\n", " proto_type_to_dtype -- 6 6 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/utils/onnx_helper.py:53:proto_type_to_dtype (proto_type_to_dtype)\n", " -- 107 107 -- 0.00003 0.00003 -- ~:0: () +++\n", " -- 108 108 -- 0.00002 0.00002 -- ~:0: () +++\n", " -- 6 6 -- 0.00040 0.00040 -- ~:0: ()\n", "inner -- 6 6 -- 0.00001 0.00001 -- C:/Python395_x64/lib/typing.py:262:inner (inner)\n", "cast -- 6 6 -- 0.00000 0.00000 -- C:/Python395_x64/lib/typing.py:1333:cast (cast)\n", "_bio_cache -- 294800 294800 -- 0.18126 0.22052 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:138:_bio_cache (_bio_cache)\n", " -- 294800 294800 -- 0.03926 0.03926 -- ~:0: () +++\n", "_bio_ptr -- 294800 294800 -- 0.20762 0.20762 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:155:_bio_ptr (_bio_ptr)\n", "_bind_input_ortvalue -- 214400 214400 -- 0.21699 0.59239 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:159:_bind_input_ortvalue (_bind_input_ortvalue)\n", " _bio_cache -- 214400 214400 -- 0.13709 0.16646 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:138:_bio_cache (_bio_cache) +++\n", " _bio_do_bind_in -- 14000 14000 -- 0.03012 0.03012 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:151:_bio_do_bind_in (_bio_do_bind_in)\n", " _bio_ptr -- 214400 214400 -- 0.15540 0.15540 -- onnxcustom/onnxcustom/training/_base_onnx_function.py:155:_bio_ptr (_bio_ptr) +++\n", " -- 214400 214400 -- 0.02341 0.02341 -- ~:0: () +++\n", "_get_att_state -- 205 205 -- 0.00007 0.00007 -- onnxcustom/onnxcustom/training/optimizers_partial.py:139:_get_att_state (_get_att_state)\n", "get_full_state -- 101 301 -- 0.00049 0.00093 -- onnxcustom/onnxcustom/training/optimizers_partial.py:147:get_full_state (get_full_state)\n", " _get_att_state -- 201 201 -- 0.00007 0.00007 -- onnxcustom/onnxcustom/training/optimizers_partial.py:139:_get_att_state (_get_att_state) +++\n", " -- 100 100 -- 
0.00021 0.00072 -- onnxcustom/onnxcustom/training/optimizers_partial.py:152: ()\n", " get_full_state -- 200 200 -- 0.00030 0.00050 -- onnxcustom/onnxcustom/training/optimizers_partial.py:147:get_full_state (get_full_state) +++\n", " -- 201 201 -- 0.00004 0.00004 -- ~:0: () +++\n", " -- 201 201 -- 0.00005 0.00005 -- ~:0: () +++\n", " -- 301 301 -- 0.00007 0.00007 -- ~:0: () +++\n", "_provider_name_to_device_type -- 9 9 -- 0.00001 0.00001 -- onnxcustom/onnxcustom/training/ortgradient.py:260:_provider_name_to_device_type (_provider_name_to_device_type)\n", "get_initializer -- 14 14 -- 0.00008 0.00175 -- onnxcustom/onnxcustom/training/ortgradient.py:269:get_initializer (get_initializer)\n", " to_array -- 12 12 -- 0.00009 0.00168 -- site-packages/onnx/numpy_helper.py:21:to_array (to_array)\n", " uses_external_data -- 12 12 -- 0.00001 0.00001 -- site-packages/onnx/external_data_helper.py:224:uses_external_data (uses_external_data)\n", " -- 12 12 -- 0.00000 0.00000 -- ~:0: () +++\n", " -- 12 12 -- 0.00006 0.00006 -- ~:0: () +++\n", " -- 12 12 -- 0.00002 0.00002 -- ~:0: () +++\n", " -- 12 12 -- 0.00148 0.00148 -- ~:0: ()\n", " -- 12 12 -- 0.00001 0.00001 -- ~:0: () +++\n", " -- 24 24 -- 0.00001 0.00001 -- ~:0: () +++\n", "input_to_ort -- 13400 13400 -- 0.12150 0.16754 -- onnxcustom/onnxcustom/training/ortgradient.py:552:input_to_ort (input_to_ort)\n", " -- 13400 13400 -- 0.01681 0.03690 -- ~:0: () +++\n", " -- 13400 13400 -- 0.00712 0.00712 -- ~:0: () +++\n", " -- 13400 13400 -- 0.00202 0.00202 -- ~:0: () +++\n", "value -- 40300 40300 -- 0.00955 0.00955 -- onnxcustom/onnxcustom/training/sgd_learning_rate.py:160:value (value)\n", "numpy_to_ort_value -- 13426 13426 -- 0.00661 0.03872 -- onnxruntime_helper.py:120:numpy_to_ort_value (numpy_to_ort_value)\n", " -- 13426 13426 -- 0.03211 0.03211 -- ~:0: () +++\n", " -- 18 18 -- 0.00014 0.00014 -- ~:0: ()\n", " -- 13575 13575 -- 0.00608 0.00608 -- ~:0: ()\n", " -- 94120 94120 -- 0.01981 0.01981 -- ~:0: ()\n", " -- 362251 362251 -- 
0.04476 0.04477 -- ~:0: ()\n", " __instancecheck__ -- 4 4 -- 0.00001 0.00001 -- C:/Python395_x64/lib/abc.py:96:__instancecheck__ (__instancecheck__)\n", " -- 67437 67437 -- 0.02341 0.02908 -- ~:0: ()\n", " __len__ -- 13600 13600 -- 0.00567 0.00567 -- onnxcustom/onnxcustom/training/data_loader.py:89:__len__ (__len__)\n", " -- 14 14 -- 0.00002 0.00002 -- ~:0: ()\n", " -- 213 213 -- 0.00005 0.00005 -- ~:0: ()\n", " -- 93826 93826 -- 0.19723 0.19723 -- ~:0: ()\n", " -- 301501 301501 -- 0.04083 0.04083 -- ~:0: ()\n", " -- 36 36 -- 0.00001 0.00001 -- ~:0: ()\n", " -- 13404 13404 -- 0.01681 0.03690 -- ~:0: ()\n", " -- 53600 53600 -- 0.01461 0.02009 -- onnxcustom/onnxcustom/training/ortgradient.py:572: ()\n"]}, {"name": "stdout", "output_type": "stream", "text": [" -- 53600 53600 -- 0.00548 0.00548 -- ~:0: () +++\n"]}], "source": ["def clean_name(text):\n", "    # Shorten a file path so that it starts at the first known package name.\n", "    pos = text.find('onnxruntime')\n", "    if pos >= 0:\n", "        return text[pos:]\n", "    pos = text.find('sklearn')\n", "    if pos >= 0:\n", "        return text[pos:]\n", "    pos = text.find('onnxcustom')\n", "    if pos >= 0:\n", "        return text[pos:]\n", "    pos = text.find('site-packages')\n", "    if pos >= 0:\n", "        return text[pos:]\n", "    return text\n", "\n", "from pyquickhelper.pycode.profiling import profile, profile2graph\n", "\n", "ps = profile(lambda: train_session2.fit(X, y))[0]\n", "root, nodes = profile2graph(ps, clean_text=clean_name)\n", "text = root.to_text()\n", "print(text)"]}, {"cell_type": "markdown", "id": "ac947d2d", "metadata": {}, "source": ["```\n", " _iteration -- 100 100 -- 0.41903 3.74610 -- \n", " iter_ortvalue -- 6800 6800 -- 0.02838 0.14761 -- \n", " _next_iter -- 6700 6700 -- 0.00946 0.07207 -- \n", " -- 6700 6700 -- 0.00245 0.00423 -- \n", " -- 6700 6700 -- 0.05838 0.05838 -- \n", " numpy_to_ort_value -- 13400 13400 -- 0.00658 0.03860 -- \n", " -- 6900 6900 -- 0.00467 0.00855 -- \n", " forward -- 6700 6700 -- 0.31685 0.44643 -- \n", " input_to_ort -- 6700 6700 -- 0.08002 0.11492 -- \n", " save_for_backward -- 6700 
6700 -- 0.01032 0.01032 -- \n", " -- 6700 6700 -- 0.00434 0.00434 -- \n", " backward -- 6700 6700 -- 0.43012 0.48957 -- \n", " input_to_ort -- 6700 6700 -- 0.04148 0.05262 -- \n", " saved_tensors -- 6700 6700 -- 0.00207 0.00207 -- \n", " -- 6700 6700 -- 0.00476 0.00476 -- \n", " loss_gradient -- 6700 6700 -- 0.05841 0.26967 -- \n", " clear_binding_inputs -- 6700 6700 -- 0.00545 0.01270 -- \n", " _cache_in_clear -- 6700 6700 -- 0.00568 0.00725 -- \n", " -- 6700 6700 -- 0.00157 0.00157 -- \n", " _bind_input_ortvalue -- 13400 13400 -- 0.02070 0.07545 -- \n", " _call_iobinding -- 6700 6700 -- 0.11997 0.11997 -- \n", " -- 13400 13400 -- 0.00315 0.00315 -- \n", " penalty_loss -- 6700 6700 -- 0.00112 0.00112 -- \n", " update_weights -- 40200 40200 -- 0.00651 0.00651 -- \n", " update_weights -- 40200 40200 -- 0.40487 1.94238 -- \n", " _bind_input_ortvalue -- 201000 201000 -- 0.19630 0.51693 -- \n", " _bind_output_ortvalue -- 80400 80400 -- 0.07458 0.18952 -- \n", " _bio_cache -- 80400 80400 -- 0.04417 0.05406 -- \n", " _bio_ptr -- 80400 80400 -- 0.05222 0.05222 -- \n", " _bio_do_bind_out -- 12 12 -- 0.00003 0.00003 -- \n", " -- 80400 80400 -- 0.00863 0.00863 -- \n", " _call_iobinding -- 40200 40200 -- 0.63987 0.63987 -- \n", " value -- 40200 40200 -- 0.00953 0.00953 -- \n", " -- 80400 80400 -- 0.16512 0.16512 -- \n", " -- 80400 80400 -- 0.01655 0.01655 -- \n", " -- 100 100 -- 0.00026 0.00426 -- \n", " _mean -- 100 100 -- 0.00163 0.00400 -- \n", " _count_reduce_items -- 100 100 -- 0.00097 0.00107 -- \n", " -- 200 200 -- 0.00010 0.00010 -- \n", " -- 100 100 -- 0.00004 0.00004 -- \n", " -- 100 100 -- 0.00109 0.00109 -- \n", " -- 100 100 -- 0.00006 0.00006 -- \n", " -- 100 100 -- 0.00004 0.00004 -- \n", " -- 200 200 -- 0.00007 0.00007 -- \n", " -- 100 100 -- 0.00358 0.00358 -- \n", " -- 6700 6700 -- 0.00169 0.00169 -- \n", " -- 40300 40300 -- 0.01424 0.01424 -- \n", " _create_training_session -- 1 1 -- 0.00001 0.02824 -- \n", " __init__ -- 1 1 -- 0.00008 0.02820 -- \n", " -- 1 
1 -- 0.00001 0.00001 -- \n", " -- 1 1 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00000 0.00000 -- \n", " _init_next -- 1 1 -- 0.00010 0.02809 -- \n", " -- 1 1 -- 0.00001 0.00001 -- \n", " -- 1 1 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00000 0.00000 -- \n", " _create_onnx_graphs -- 1 1 -- 0.00662 0.02797 -- \n", " -- 1 1 -- 0.00001 0.00001 -- \n", " -- 1 1 -- 0.00001 0.00001 -- \n", " -- 1 1 -- 0.00001 0.00002 -- \n", " _provider_name_to_device_type -- 1 1 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00002 0.00002 -- \n", " _provider_name_to_device_type -- 7 7 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00000 0.00000 -- \n", " _provider_name_to_device_type -- 1 1 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00001 0.00001 -- \n", " -- 1 1 -- 0.00001 0.00001 -- \n", " get_inputs -- 1 1 -- 0.00000 0.00000 -- \n", " get_outputs -- 1 1 -- 0.00000 0.00000 -- \n", " __init__ -- 2 2 -- 0.00004 0.02063 -- \n", " get -- 2 2 -- 0.00001 0.00004 -- \n", " __getitem__ -- 2 2 -- 0.00001 0.00003 -- \n", " encodekey -- 2 2 -- 0.00001 0.00002 -- \n", " check_str -- 2 2 -- 0.00000 0.00000 -- \n", " __init__ -- 2 2 -- 0.00000 0.00000 -- \n", " _create_inference_session -- 2 2 -- 0.02045 0.02055 -- \n", " check_and_nor...rovider_args -- 2 2 -- 0.00004 0.00008 -- \n", " set_provider_options -- 2 2 -- 0.00001 0.00001 -- \n", " -- 2 2 -- 0.00000 0.00000 -- \n", " -- 2 2 -- 0.00000 0.00000 -- \n", " -- 2 2 -- 0.00000 0.00000 -- \n", " load_model -- 2 2 -- 0.00001 0.00049 -- \n", " _load_bytes -- 2 2 -- 0.00002 0.00003 -- \n", " inner -- 4 4 -- 0.00000 0.00000 -- \n", " cast -- 4 4 -- 0.00000 0.00000 -- \n", " _get_file_path -- 2 2 -- 0.00000 0.00000 -- \n", " load_model_from_string -- 2 2 -- 0.00001 0.00045 -- \n", " _deserialize -- 2 2 -- 0.00001 0.00044 -- \n", " inner -- 2 2 -- 0.00000 0.00000 -- \n", " cast -- 2 2 -- 0.00000 0.00000 -- \n", " -- 2 2 -- 0.00042 0.00042 -- \n", " -- 16 16 -- 0.00000 0.00000 -- \n", " -- 1 1 -- 0.00014 0.00014 -- \n", " new_instance -- 1 1 -- 0.00000 0.00000 -- \n", " 
__init__ -- 1 1 -- 0.00000 0.00000 -- \n", " device_to_providers -- 1 1 -- 0.00003 0.00003 -- \n", " value -- 100 100 -- 0.00003 0.00003 -- \n", "\n", "```"]}, {"cell_type": "code", "execution_count": 49, "id": "af74c9ce", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'model_onnx': 'mlp_onnx_ort\\\\GradFBOptimizer.model_onnx.onnx',\n", " 'learning_rate': {'axpyw_onnx_': 'mlp_onnx_ort\\\\LRateSGDNesterov.learning_rate.axpyw_onnx_.onnx'},\n", " 'learning_loss': {'loss_grad_onnx_': 'mlp_onnx_ort\\\\SquareLLoss.learning_loss.loss_grad_onnx_.onnx',\n", " 'loss_score_onnx_': 'mlp_onnx_ort\\\\SquareLLoss.learning_loss.loss_score_onnx_.onnx'},\n", " 'learning_penalty': {},\n", " 'zero_onnx_': 'mlp_onnx_ort\\\\GradFBOptimizer.zero_onnx_.onnx',\n", " 'train_function_': {'_trained_onnx': 'mlp_onnx_ort\\\\OrtGradientForwardBackwardFunction_1523278698000.train_function_._trained_onnx.onnx',\n", " '_optimized_pre_grad_model': 'mlp_onnx_ort\\\\OrtGradientForwardBackwardFunction_1523278698000.train_function_._optimized_pre_grad_model.onnx'}}"]}, "execution_count": 50, "metadata": {}, "output_type": "execute_result"}], "source": ["import os\n", "\n", "# Create the output folder if needed, then save every ONNX graph used by the optimizer.\n", "os.makedirs(\"mlp_onnx_ort\", exist_ok=True)\n", "train_session2.save_onnx_graph(\"mlp_onnx_ort\")"]}, {"cell_type": "markdown", "id": "1fce825b", "metadata": {}, "source": ["Weights are updated with the following ONNX graph:"]}, {"cell_type": "code", "execution_count": 50, "id": "bb84a7c6", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 51, "metadata": {}, "output_type": "execute_result"}], "source": ["%onnxview train_session2.learning_rate.axpyw_onnx_"]}, {"cell_type": "code", "execution_count": 51, "id": "85bb6294", "metadata": {}, "outputs": [], "source": []}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5"}}, "nbformat": 4, "nbformat_minor": 5}