{"cells": [{"cell_type": "markdown", "id": "4b855db5", "metadata": {}, "source": ["# NeuralTreeNet and ONNX\n", "\n", "Converting a decision tree to ONNX can introduce differences between the original model and the converted one (see [Issues when switching to float](http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/gyexamples/plot_ebegin_float_double.html)). The problem comes from a change of type: each decision threshold is rounded to the float32 value closest to its float64 (double) value. What happens if the decision tree is first converted into a neural network?\n", "\n", "In most cases, rounding the decision thresholds changes very little. However, comparing a feature with a rounded threshold can give the opposite result from comparing it with the original threshold. When that happens, the observation follows a different path in the tree and ends in a different leaf."]}, {"cell_type": "code", "execution_count": 1, "id": "636a122a", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 2, "metadata": {}, "output_type": "execute_result"}], "source": ["from jyquickhelper import add_notebook_menu\n", "add_notebook_menu()"]}, {"cell_type": "code", "execution_count": 2, "id": "2f698cc0", "metadata": {}, "outputs": [], "source": ["%matplotlib inline"]}, {"cell_type": "code", "execution_count": 3, "id": "ad53d7c6", "metadata": {}, "outputs": [], "source": ["%load_ext mlprodict"]}, {"cell_type": "markdown", "id": "c7b2fb41", "metadata": {}, "source": ["## Dataset\n", "\n", "We build a random dataset: 10,000 observations with 10 features, and a target equal to the mean of the features."]}, {"cell_type": "code", "execution_count": 4, "id": "a8feffa5", "metadata": {}, "outputs": [], "source": ["import numpy\n", "\n", "X = numpy.random.randn(10000, 10)\n", "y = X.sum(axis=1) / X.shape[1]\n", "X = X.astype(numpy.float64)\n", "y = y.astype(numpy.float64)"]}, {"cell_type": "code", "execution_count": 5, "id": "3c854905", "metadata": {}, "outputs": [], "source": ["middle = X.shape[0] // 2\n", "X_train, X_test = X[:middle], X[middle:]\n", "y_train, y_test = y[:middle], y[middle:]"]}, {"cell_type": "markdown", "id": "2972ef7f", "metadata": {}, "source": ["## scikit-learn part"]}, {"cell_type": "markdown", "id": "2a19a0c1", "metadata": {}, "source": ["### Fitting a decision tree"]}, {"cell_type": "code", "execution_count": 6, "id": "bfc49123", "metadata": {}, "outputs": [{"data": {"text/plain": ["(0.6179766027481131, 0.33709933420465643)"]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.tree import DecisionTreeRegressor\n", "\n", "tree = DecisionTreeRegressor(max_depth=7)\n", "tree.fit(X_train, y_train)\n", "tree.score(X_train, y_train), tree.score(X_test, y_test)"]}, {"cell_type": "code", "execution_count": 7, "id": "a38b0426", "metadata": {}, "outputs": [{"data": {"text/plain": ["0.33709933420465643"]}, "execution_count": 8, "metadata": {}, "output_type": "execute_result"}], "source": ["from 
sklearn.metrics import r2_score\n", "r2_score(y_test, tree.predict(X_test))"]}, {"cell_type": "markdown", "id": "86a0f0a3", "metadata": {}, "source": ["The tree is not deep enough to fit these data well (R\u00b2 of 0.34 on the test set), but that is not what matters here."]}, {"cell_type": "markdown", "id": "8e6038ff", "metadata": {}, "source": ["### Conversion to ONNX"]}, {"cell_type": "code", "execution_count": 8, "id": "f6849a2d", "metadata": {}, "outputs": [], "source": ["from mlprodict.onnx_conv import to_onnx\n", "\n", "onx = to_onnx(tree, X[:1].astype(numpy.float32))"]}, {"cell_type": "code", "execution_count": 9, "id": "3daf9db1", "metadata": {}, "outputs": [{"data": {"text/plain": ["1.7421041873949668"]}, "execution_count": 10, "metadata": {}, "output_type": "execute_result"}], "source": ["from mlprodict.onnxrt import OnnxInference\n", "\n", "x_exp = X_test\n", "\n", "oinf = OnnxInference(onx, runtime='onnxruntime1')\n", "expected = tree.predict(x_exp)\n", "\n", "got = oinf.run({'X': x_exp.astype(numpy.float32)})['variable']\n", "numpy.abs(got - expected).max()"]}, {"cell_type": "code", "execution_count": 10, "id": "7ce247da", "metadata": {"scrolled": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='ai.onnx.ml' version=1\n", "opset: domain='' version=15\n", "input: name='X' type=dtype('float32') shape=[None, 10]\n", "TreeEnsembleRegressor(X, n_targets=1, nodes_falsenodeids=253:[128,65,34...252,0,0], nodes_featureids=253:[8,3,9...2,0,0], nodes_hitrates=253:[1.0,1.0...1.0,1.0], nodes_missing_value_tracks_true=253:[0,0,0...0,0,0], nodes_modes=253:[b'BRANCH_LEQ',b'BRANCH_LEQ'...b'LEAF',b'LEAF'], nodes_nodeids=253:[0,1,2...250,251,252], nodes_treeids=253:[0,0,0...0,0,0], nodes_truenodeids=253:[1,2,3...251,0,0], nodes_values=253:[0.00792999193072319,-0.12246682494878769...0.0,0.0], post_transform=b'NONE', target_ids=127:[0,0,0...0,0,0], target_nodeids=127:[7,8,10...249,251,252], target_treeids=127:[0,0,0...0,0,0], 
target_weights=127:[-0.9345570802688599,-0.6372960805892944...0.6169403195381165,1.0096807479858398]) -> variable\n", "output: name='variable' type=dtype('float32') shape=[None, 1]\n"]}], "source": ["from mlprodict.plotting.text_plot import onnx_simple_text_plot\n", "print(onnx_simple_text_plot(onx))"]}, {"cell_type": "markdown", "id": "1ada8e37", "metadata": {}, "source": ["## After conversion to a neural network"]}, {"cell_type": "markdown", "id": "7238d09b", "metadata": {}, "source": ["### Converting the tree into a neural network\n", "\n", "The parameter `k` sets the slope of the sigmoid functions that replace the threshold comparisons of the tree. The loop below measures, for several values of `k`, the mean and maximum absolute differences with the predictions of the tree on the first 500 test observations."]}, {"cell_type": "code", "execution_count": 11, "id": "7729c242", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 18/18 [00:01<00:00, 12.49it/s]\n"]}], "source": ["from tqdm import tqdm\n", "from pandas import DataFrame\n", "from mlstatpy.ml.neural_tree import NeuralTreeNet\n", "\n", "xe = x_exp[:500]\n", "expected = tree.predict(xe)\n", "\n", "data = []\n", "trees = {}\n", "for i in tqdm([0.3, 0.4, 0.5, 0.7, 0.9, 1] + list(range(5, 61, 5))):\n", " root = NeuralTreeNet.create_from_tree(tree, k=i, arch='compact')\n", " got = root.predict(xe)[:, -1]\n", " me = numpy.abs(got - expected).mean()\n", " mx = numpy.abs(got - expected).max()\n", " obs = dict(k=i, max=mx, mean=me)\n", " data.append(obs)\n", " trees[i] = root"]}, {"cell_type": "code", "execution_count": 12, "id": "9d35377e", "metadata": {}, "outputs": [{"data": {"text/html": ["
"], "text/plain": [" k max mean\n", "0 0.3 0.568981 0.158758\n", "1 0.4 0.608304 0.132576\n", "2 0.5 0.692657 0.128525\n", "3 0.7 0.780543 0.131497\n", "4 0.9 0.809866 0.128368\n", "5 1.0 0.813889 0.124802\n", "6 5.0 0.392482 0.022466\n", "7 10.0 0.341749 0.006350\n", "8 15.0 0.270649 0.002939\n", "9 20.0 0.299713 0.002110\n", "10 25.0 0.305493 0.001842\n", "11 30.0 0.306111 0.001767\n", "12 35.0 0.299371 0.001665\n", "13 40.0 0.233556 0.001011\n", "14 45.0 0.233606 0.000801\n", "15 50.0 0.233614 0.000547\n", "16 55.0 0.233615 0.000499\n", "17 60.0 0.233615 0.000484"]}, "execution_count": 13, "metadata": {}, "output_type": "execute_result"}], "source": ["df = DataFrame(data)\n", "df"]}, {"cell_type": "code", "execution_count": 13, "id": "0fcb9789", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["df.set_index('k').plot(title=\"Accuracy of the conversion\\ninto a neural network\");"]}, {"cell_type": "markdown", "id": "1f4bb3d9", "metadata": {}, "source": ["The error shrinks as `k` grows (clearly for the mean error, less regularly for the maximum error), but the experiment would have to be repeated several times, to obtain a confidence interval for this kind of dataset, before drawing any conclusion. That is left for another time. The result depends on the dataset and above all on how close the observations lie to the decision thresholds. We now compute the error on the whole test set: the previous computation only used its first 500 observations, to run faster."]}, {"cell_type": "code", "execution_count": 14, "id": "2f3eb6d0", "metadata": {}, "outputs": [{"data": {"text/plain": ["(0.2336143002078063, 0.0002511855017989173)"]}, "execution_count": 15, "metadata": {}, "output_type": "execute_result"}], "source": ["expected = tree.predict(x_exp)\n", "got = trees[50].predict(x_exp)[:, -1]\n", "numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()"]}, {"cell_type": "markdown", "id": "77163512", "metadata": {}, "source": ["The maximum error can still be large (about 0.23), even though the mean error is very small (about 0.00025). 
This maximum error nevertheless remains well below the one introduced by the ONNX conversion of the tree (about 1.74)."]}, {"cell_type": "markdown", "id": "738c8547", "metadata": {}, "source": ["### Conversion to ONNX\n", "\n", "We first create a class that follows the scikit-learn API and wraps the neural network built with `k=50` (`trees[50]`); this class is then converted to ONNX."]}, {"cell_type": "code", "execution_count": 15, "id": "2439e4fa", "metadata": {}, "outputs": [], "source": ["from mlstatpy.ml.neural_tree import NeuralTreeNetRegressor\n", "\n", "reg = NeuralTreeNetRegressor(trees[50])\n", "onx2 = to_onnx(reg, X[:1].astype(numpy.float32))"]}, {"cell_type": "code", "execution_count": 16, "id": "eae47e6a", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=15\n", "input: name='X' type=dtype('float32') shape=[None, 10]\n", "init: name='Ma_MatMulcst' type=dtype('float32') shape=(1260,)\n", "init: name='Ad_Addcst' type=dtype('float32') shape=(126,)\n", "init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([4.], dtype=float32)\n", "init: name='Ma_MatMulcst1' type=dtype('float32') shape=(16002,)\n", "init: name='Ad_Addcst1' type=dtype('float32') shape=(127,)\n", "init: name='Ma_MatMulcst2' type=dtype('float32') shape=(127,)\n", "init: name='Ad_Addcst2' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)\n", "MatMul(X, Ma_MatMulcst) -> Ma_Y02\n", " Add(Ma_Y02, Ad_Addcst) -> Ad_C02\n", " Mul(Ad_C02, Mu_Mulcst) -> Mu_C01\n", " Sigmoid(Mu_C01) -> Si_Y01\n", " MatMul(Si_Y01, Ma_MatMulcst1) -> Ma_Y01\n", " Add(Ma_Y01, Ad_Addcst1) -> Ad_C01\n", " Mul(Ad_C01, Mu_Mulcst) -> Mu_C0\n", " Sigmoid(Mu_C0) -> Si_Y0\n", " MatMul(Si_Y0, Ma_MatMulcst2) -> Ma_Y0\n", " Add(Ma_Y0, Ad_Addcst2) -> Ad_C0\n", " Identity(Ad_C0) -> variable\n", "output: name='variable' type=dtype('float32') shape=[None, 1]\n"]}], "source": ["print(onnx_simple_text_plot(onx2))"]}, {"cell_type": "code", "execution_count": 17, "id": "1d4e272f", 
"metadata": {}, "outputs": [{"data": {"text/plain": ["1.7421041873949668"]}, "execution_count": 18, "metadata": {}, "output_type": "execute_result"}], "source": ["oinf2 = OnnxInference(onx2, runtime='onnxruntime1')\n", "expected = tree.predict(x_exp)\n", "\n", "got = oinf2.run({'X': x_exp.astype(numpy.float32)})['variable']\n", "numpy.abs(got - expected).max()"]}, {"cell_type": "markdown", "id": "f4e64f63", "metadata": {}, "source": ["The maximum error is the same as for the decision tree converted to ONNX (about 1.742), whereas it was only about 0.23 before the conversion."]}, {"cell_type": "markdown", "id": "c9207392", "metadata": {}, "source": ["## Computation time"]}, {"cell_type": "code", "execution_count": 18, "id": "a6febd37", "metadata": {}, "outputs": [], "source": ["x_exp32 = x_exp.astype(numpy.float32)"]}, {"cell_type": "markdown", "id": "1bf0109e", "metadata": {}, "source": ["First, the computation time with scikit-learn."]}, {"cell_type": "code", "execution_count": 19, "id": "07caad53", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["513 \u00b5s \u00b1 7.52 \u00b5s per loop (mean \u00b1 std. dev. of 7 runs, 1,000 loops each)\n"]}], "source": ["%timeit tree.predict(x_exp32)"]}, {"cell_type": "markdown", "id": "0cea5139", "metadata": {}, "source": ["Then the computation time for the decision tree in ONNX format."]}, {"cell_type": "code", "execution_count": 20, "id": "984413fa", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["186 \u00b5s \u00b1 3.41 \u00b5s per loop (mean \u00b1 std. dev. of 7 runs, 10,000 loops each)\n"]}], "source": ["%timeit oinf.run({'X': x_exp32})['variable']"]}, {"cell_type": "markdown", "id": "afb4f6bb", "metadata": {}, "source": ["And the computation time for the neural network in ONNX format."]}, {"cell_type": "code", "execution_count": 21, "id": "e3268dcd", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["3.75 ms \u00b1 311 \u00b5s per loop (mean \u00b1 std. dev. 
of 7 runs, 100 loops each)\n"]}], "source": ["%timeit oinf2.run({'X': x_exp32})['variable']"]}, {"cell_type": "markdown", "id": "b3eafba0", "metadata": {}, "source": ["This long computation time is expected: the model contains a very large matrix multiplication and, above all, every threshold of the tree is evaluated for every observation. Where the decision tree evaluates at most *d* thresholds, *d* being the depth of the tree, the new implementation evaluates all of them, about $2^d$, and each of the $2^d$ leaves then combines all these thresholds, which amounts to about $2^{2d}$ operations per observation. Even with sparse matrices, this can only be reduced to about $d \\cdot 2^d$ operations, which is still a lot of useless computation."]}, {"cell_type": "code", "execution_count": 22, "id": "d9911fff", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["(126, 11) (126,)\n", "(127, 127) (127,)\n", "(128,) ()\n"]}], "source": ["for node in trees[50].nodes:\n", " print(node.coef.shape, node.bias.shape)"]}, {"cell_type": "markdown", "id": "27e187ac", "metadata": {}, "source": ["That said, the largest matrix is sparse and could be made considerably smaller, as the count of nonzero coefficients below shows."]}, {"cell_type": "code", "execution_count": 23, "id": "e97479fe", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["coef.shape=(126, 11), size dense=1386, size sparse=252, ratio=0.18181818181818182\n", "coef.shape=(127, 127), size dense=16129, size sparse=1015, ratio=0.06293012586025172\n", "coef.shape=(128,), size dense=128, size sparse=127, ratio=0.9921875\n"]}], "source": ["from scipy.sparse import csr_matrix\n", "\n", "for node in trees[50].nodes:\n", " csr = csr_matrix(node.coef)\n", " print(f\"coef.shape={node.coef.shape}, size dense={node.coef.size}, \"\n", " f\"size sparse={csr.size}, ratio={csr.size / node.coef.size}\")"]}, {"cell_type": "code", "execution_count": 24, "id": "125547d9", "metadata": {}, "outputs": [{"name": "stdout", 
"output_type": "stream", "text": ["49.8 \u00b5s \u00b1 1.25 \u00b5s per loop (mean \u00b1 std. dev. of 7 runs, 10,000 loops each)\n"]}], "source": ["r = numpy.random.randn(trees[50].nodes[1].coef.shape[0])\n", "mat = trees[50].nodes[1].coef\n", "%timeit mat @ r"]}, {"cell_type": "code", "execution_count": 25, "id": "ad7173e5", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["7.08 \u00b5s \u00b1 173 ns per loop (mean \u00b1 std. dev. of 7 runs, 100,000 loops each)\n"]}], "source": ["csr = csr_matrix(mat)\n", "%timeit csr @ r"]}, {"cell_type": "markdown", "id": "7599d94e", "metadata": {}, "source": ["The product is about 7 times faster with a sparse matrix (7.08 \u00b5s instead of 49.8 \u00b5s), and the gain should grow with the depth of the tree, since the proportion of nonzero coefficients decreases. The ONNX model breaks down as follows."]}, {"cell_type": "code", "execution_count": 26, "id": "0c1839fd", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=15\n", "input: name='X' type=dtype('float32') shape=[None, 10]\n", "init: name='Ma_MatMulcst' type=dtype('float32') shape=(1260,)\n", "init: name='Ad_Addcst' type=dtype('float32') shape=(126,)\n", "init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([4.], dtype=float32)\n", "init: name='Ma_MatMulcst1' type=dtype('float32') shape=(16002,)\n", "init: name='Ad_Addcst1' type=dtype('float32') shape=(127,)\n", "init: name='Ma_MatMulcst2' type=dtype('float32') shape=(127,)\n", "init: name='Ad_Addcst2' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)\n", "MatMul(X, Ma_MatMulcst) -> Ma_Y02\n", " Add(Ma_Y02, Ad_Addcst) -> Ad_C02\n", " Mul(Ad_C02, Mu_Mulcst) -> Mu_C01\n", " Sigmoid(Mu_C01) -> Si_Y01\n", " MatMul(Si_Y01, Ma_MatMulcst1) -> Ma_Y01\n", " Add(Ma_Y01, Ad_Addcst1) -> Ad_C01\n", " Mul(Ad_C01, Mu_Mulcst) -> Mu_C0\n", " Sigmoid(Mu_C0) -> Si_Y0\n", " MatMul(Si_Y0, Ma_MatMulcst2) -> Ma_Y0\n", " Add(Ma_Y0, Ad_Addcst2) -> Ad_C0\n", " Identity(Ad_C0) -> variable\n", "output: name='variable' 
type=dtype('float32') shape=[None, 1]\n"]}], "source": ["print(onnx_simple_text_plot(onx2))"]}, {"cell_type": "markdown", "id": "318b95d7", "metadata": {}, "source": ["Let's see how the computation time is distributed among the nodes of the graph, with onnxruntime profiling enabled over 43 runs."]}, {"cell_type": "code", "execution_count": 27, "id": "11bccd22", "metadata": {}, "outputs": [], "source": ["oinfpr = OnnxInference(onx2, runtime=\"onnxruntime1\",\n", " runtime_options={\"enable_profiling\": True})\n", "for i in range(0, 43):\n", " oinfpr.run({\"X\": x_exp32})"]}, {"cell_type": "code", "execution_count": 28, "id": "5485970b", "metadata": {}, "outputs": [{"data": {"text/html": ["
"], "text/plain": [" cat pid tid dur ts ph name \\\n", "0 Session 78116 8820 387 4 X model_loading_array \n", "1 Session 78116 8820 2532 428 X session_initialization \n", "2 Node 78116 8820 0 3294 X gemm_fence_before \n", "3 Node 78116 8820 1315 3300 X gemm_kernel_time \n", "4 Node 78116 8820 0 4635 X gemm_fence_after \n", ".. ... ... ... ... ... .. ... \n", "986 Node 78116 8820 0 210170 X Ma_MatMul2_fence_before \n", "987 Node 78116 8820 124 210172 X Ma_MatMul2_kernel_time \n", "988 Node 78116 8820 0 210305 X Ma_MatMul2_fence_after \n", "989 Session 78116 8820 4378 205930 X SequentialExecutor::Execute \n", "990 Session 78116 8820 4388 205925 X model_run \n", "\n", " args_op_name args_parameter_size args_graph_index args_provider \\\n", "0 NaN NaN NaN NaN \n", "1 NaN NaN NaN NaN \n", "2 Gemm NaN NaN NaN \n", "3 Gemm 5544 11 CPUExecutionProvider \n", "4 Gemm NaN NaN NaN \n", ".. ... ... ... ... \n", "986 MatMul NaN NaN NaN \n", "987 MatMul 508 8 CPUExecutionProvider \n", "988 MatMul NaN NaN NaN \n", "989 NaN NaN NaN NaN \n", "990 NaN NaN NaN NaN \n", "\n", " args_exec_plan_index args_activation_size args_output_size \\\n", "0 NaN NaN NaN \n", "1 NaN NaN NaN \n", "2 NaN NaN NaN \n", "3 11 200000 2520000 \n", "4 NaN NaN NaN \n", ".. ... ... ... \n", "986 NaN NaN NaN \n", "987 8 2540000 20000 \n", "988 NaN NaN NaN \n", "989 NaN NaN NaN \n", "990 NaN NaN NaN \n", "\n", " args_input_type_shape \\\n", "0 NaN \n", "1 NaN \n", "2 NaN \n", "3 [{'float': [5000, 10]}, {'float': [10, 126]}, ... \n", "4 NaN \n", ".. ... \n", "986 NaN \n", "987 [{'float': [5000, 127]}, {'float': [127, 1]}] \n", "988 NaN \n", "989 NaN \n", "990 NaN \n", "\n", " args_output_type_shape \\\n", "0 NaN \n", "1 NaN \n", "2 NaN \n", "3 [{'float': [5000, 126]}] \n", "4 NaN \n", ".. ... 
\n", "986 NaN \n", "987 [{'float': [5000, 1]}] \n", "988 NaN \n", "989 NaN \n", "990 NaN \n", "\n", " args_thread_scheduling_stats \n", "0 NaN \n", "1 NaN \n", "2 NaN \n", "3 {'main_thread': {'thread_pool_name': 'session-... \n", "4 NaN \n", ".. ... \n", "986 NaN \n", "987 {'main_thread': {'thread_pool_name': 'session-... \n", "988 NaN \n", "989 NaN \n", "990 NaN \n", "\n", "[991 rows x 17 columns]"]}, "execution_count": 29, "metadata": {}, "output_type": "execute_result"}], "source": ["df = oinfpr.get_profiling(as_df=True)\n", "df"]}, {"cell_type": "code", "execution_count": 29, "id": "19bb5d0f", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'CPUExecutionProvider', nan}"]}, "execution_count": 30, "metadata": {}, "output_type": "execute_result"}], "source": ["set(df['args_provider'])"]}, {"cell_type": "code", "execution_count": 30, "id": "e42d5644", "metadata": {}, "outputs": [{"data": {"text/html": ["
"], "text/plain": [" dur\n", "args_op_name name \n", "MatMul Ma_MatMul2 6778\n", "Mul Mu_Mul 12923\n", "Sigmoid Si_Sigmoid 14849\n", "Mul Mu_Mul1 15151\n", "Sigmoid Si_Sigmoid1 15608\n", "Gemm gemm 31763\n", " gemm_token_0 99047"]}, "execution_count": 31, "metadata": {}, "output_type": "execute_result"}], "source": ["dfp = df[df.args_provider == 'CPUExecutionProvider'].copy()\n", "dfp['name'] = dfp['name'].apply(lambda s: s.replace(\"_kernel_time\", \"\"))\n", "gr_dur = dfp[['dur', \"args_op_name\", \"name\"]].groupby([\"args_op_name\", \"name\"]).sum().sort_values('dur')\n", "gr_dur"]}, {"cell_type": "code", "execution_count": 31, "id": "34b33616", "metadata": {}, "outputs": [{"data": {"text/html": ["
"], "text/plain": [" dur\n", "args_op_name name \n", "MatMul Ma_MatMul2 43\n", "Mul Mu_Mul 43\n", "Sigmoid Si_Sigmoid 43\n", "Mul Mu_Mul1 43\n", "Sigmoid Si_Sigmoid1 43\n", "Gemm gemm 43\n", " gemm_token_0 43"]}, "execution_count": 32, "metadata": {}, "output_type": "execute_result"}], "source": ["gr_n = dfp[['dur', \"args_op_name\", \"name\"]].groupby([\"args_op_name\", \"name\"]).count().sort_values('dur')\n", "gr_n = gr_n.loc[gr_dur.index, :]\n", "gr_n"]}, {"cell_type": "code", "execution_count": 32, "id": "f34b2908", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["import matplotlib.pyplot as plt\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(12, 4))\n", "gr_dur.plot.barh(ax=ax[0])\n", "gr_n.plot.barh(ax=ax[1])\n", "ax[0].set_title(\"duration\")\n", "ax[1].set_title(\"n occurrences\");"]}, {"cell_type": "markdown", "id": "7b10ca8a", "metadata": {}, "source": ["onnxruntime spends most of its time in matrix products, above all in the `Gemm` node `gemm_token_0`. Let's look at it more closely."]}, {"cell_type": "code", "execution_count": 33, "id": "4cbc2fa0", "metadata": {}, "outputs": [{"data": {"text/html": ["
"], "text/plain": [" 127 \\\n", "cat Node \n", "pid 78116 \n", "tid 8820 \n", "dur 4603 \n", "ts 37173 \n", "ph X \n", "name gemm_token_0_kernel_time \n", "args_op_name Gemm \n", "args_parameter_size 64516 \n", "args_graph_index 12 \n", "args_provider CPUExecutionProvider \n", "args_exec_plan_index 12 \n", "args_activation_size 2520000 \n", "args_output_size 2540000 \n", "args_input_type_shape [{'float': [5000, 126]}, {'float': [126, 127]}... \n", "args_output_type_shape [{'float': [5000, 127]}] \n", "args_thread_scheduling_stats {'main_thread': {'thread_pool_name': 'session-... \n", "\n", " 12 \n", "cat Node \n", "pid 78116 \n", "tid 8820 \n", "dur 4083 \n", "ts 5949 \n", "ph X \n", "name gemm_token_0_kernel_time \n", "args_op_name Gemm \n", "args_parameter_size 64516 \n", "args_graph_index 12 \n", "args_provider CPUExecutionProvider \n", "args_exec_plan_index 12 \n", "args_activation_size 2520000 \n", "args_output_size 2540000 \n", "args_input_type_shape [{'float': [5000, 126]}, {'float': [126, 127]}... \n", "args_output_type_shape [{'float': [5000, 127]}] \n", "args_thread_scheduling_stats {'main_thread': {'thread_pool_name': 'session-... "]}, "execution_count": 34, "metadata": {}, "output_type": "execute_result"}], "source": ["df[(df.args_op_name == 'Gemm') & (df.dur > 0)].sort_values('dur', ascending=False).head(n=2).T"]}, {"cell_type": "markdown", "id": "58320942", "metadata": {}, "source": ["It is the product of a *5000x126* matrix by a *126x127* matrix: the outputs of the first layer for the 5,000 test observations times the weights of the second layer."]}, {"cell_type": "code", "execution_count": 34, "id": "de43df2f", "metadata": {}, "outputs": [{"data": {"text/html": ["
"], "text/plain": [" dur\n", "args_op_name name \n", "MatMul Ma_MatMul2 0.034561\n", "Mul Mu_Mul 0.065894\n", "Sigmoid Si_Sigmoid 0.075714\n", "Mul Mu_Mul1 0.077254\n", "Sigmoid Si_Sigmoid1 0.079584\n", "Gemm gemm 0.161958\n", " gemm_token_0 0.505035"]}, "execution_count": 35, "metadata": {}, "output_type": "execute_result"}], "source": ["gr_dur / gr_dur.dur.sum()"]}, {"cell_type": "code", "execution_count": 35, "id": "0e5c02ec", "metadata": {}, "outputs": [{"data": {"text/plain": ["0.5050352082154203"]}, "execution_count": 36, "metadata": {}, "output_type": "execute_result"}], "source": ["r = (gr_dur / gr_dur.dur.sum()).dur.max()\n", "r"]}, {"cell_type": "markdown", "id": "113a480a", "metadata": {}, "source": ["This node accounts for about 50% of the total time. According to the previous experiment, replacing it with a sparse matrix product could divide its execution time by about 7 (49.8 \u00b5s versus 7.08 \u00b5s); the estimate below uses a more optimistic factor of 12. That will not be enough to make this neural network fast: it takes 3.75 ms, compared with 186 \u00b5s for the decision tree in ONNX (513 \u00b5s with scikit-learn). With this optimization, its computation time could go from 3.75 ms down to (in ms):"]}, {"cell_type": "code", "execution_count": 36, "id": "fa7950bc", "metadata": {}, "outputs": [{"data": {"text/plain": ["2.013941471759493"]}, "execution_count": 37, "metadata": {}, "output_type": "execute_result"}], "source": ["t = 3.75 # ms\n", "t * (1 - r) + r * t / 12"]}, {"cell_type": "markdown", "id": "7c641d19", "metadata": {}, "source": ["That is, the computation time would be roughly halved (about 2.0 ms instead of 3.75 ms). Not bad, but not enough: the network would still be about ten times slower than the decision tree in ONNX."]}, {"cell_type": "markdown", "id": "535b7e56", "metadata": {}, "source": ["## Hummingbird\n", "\n", "[hummingbird](https://github.com/microsoft/hummingbird) is a library that converts a decision tree into a neural network, here a PyTorch model. 
Let's look at its performance."]}, {"cell_type": "code", "execution_count": 37, "id": "3b3aa43b", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["C:\\xavierdupre\\__home_\\github_fork\\scikit-learn\\sklearn\\utils\\deprecation.py:103: FutureWarning: The attribute `n_features_` is deprecated in 1.0 and will be removed in 1.2. Use `n_features_in_` instead.\n", " warnings.warn(msg, category=FutureWarning)\n"]}, {"data": {"text/plain": ["(4.3419181139370266e-08, 4.430287026515114e-09)"]}, "execution_count": 38, "metadata": {}, "output_type": "execute_result"}], "source": ["from hummingbird.ml import convert\n", "\n", "model = convert(tree, 'torch')\n", "\n", "expected = tree.predict(x_exp)\n", "got = model.predict(x_exp)\n", "numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()"]}, {"cell_type": "markdown", "id": "92365d70", "metadata": {}, "source": ["The result is much closer to the original model: the maximum difference is about 4e-8."]}, {"cell_type": "code", "execution_count": 38, "id": "605df039", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["1.17 ms \u00b1 34.8 \u00b5s per loop (mean \u00b1 std. dev. of 7 runs, 1,000 loops each)\n"]}], "source": ["%timeit model.predict(x_exp)"]}, {"cell_type": "markdown", "id": "c2f80290", "metadata": {}, "source": ["It is still slower than the decision tree but much faster than the manual solution proposed in the previous paragraphs. 
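Such a conversion can stay exact because it keeps hard comparisons instead of sigmoids. A minimal sketch, assuming NumPy and scikit-learn only, of one way a fitted tree can be evaluated with matrix products and tests. It follows the idea of the GEMM strategy described by the hummingbird authors; the library also implements level-by-level traversal strategies, and the ONNX graph exported below uses one of those (Gather and LessOrEqual at each level). The helpers `gemm_matrices` and `gemm_predict` are hypothetical names, not part of hummingbird's API.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def gemm_matrices(tree):
    # Hypothetical helper: builds the tensors of a GEMM-style evaluation
    # from a fitted scikit-learn tree.
    t = tree.tree_
    internal = np.where(t.children_left != -1)[0]  # nodes holding a test
    leaves = np.where(t.children_left == -1)[0]
    col = {node: j for j, node in enumerate(internal)}
    leaf_col = {node: k for k, node in enumerate(leaves)}

    # A picks the feature tested by each internal node (one-hot columns).
    A = np.zeros((tree.n_features_in_, len(internal)))
    A[t.feature[internal], np.arange(len(internal))] = 1.0
    # B holds the thresholds (stored as float64 by scikit-learn).
    B = t.threshold[internal]

    # C[j, k] = +1 if leaf k lies in the left subtree of node j,
    # -1 if it lies in the right subtree, 0 otherwise.
    C = np.zeros((len(internal), len(leaves)))
    stack = [(0, [])]
    while stack:
        node, path = stack.pop()
        if t.children_left[node] == -1:
            for j, sign in path:
                C[j, leaf_col[node]] = sign
            continue
        j = col[node]
        stack.append((t.children_left[node], path + [(j, 1.0)]))
        stack.append((t.children_right[node], path + [(j, -1.0)]))

    # D counts the left turns leading to each leaf, E stores leaf values.
    D = (C == 1).sum(axis=0)
    E = t.value[leaves, 0, 0]
    return A, B, C, D, E


def gemm_predict(X, A, B, C, D, E):
    # T[i, j] = 1 when sample i goes left at node j (x <= threshold).
    T = (X.astype(np.float64) @ A <= B).astype(np.float64)
    # Each wrong turn on the path to a leaf lowers T @ C by one,
    # so only the leaf actually reached satisfies T @ C == D.
    reached = (T @ C == D).astype(np.float64)
    return reached @ E


rng = np.random.default_rng(0)
X_demo = rng.standard_normal((500, 5)).astype(np.float32)  # float32 on purpose
y_demo = X_demo.sum(axis=1)
demo_tree = DecisionTreeRegressor(max_depth=4).fit(X_demo, y_demo)
mats = gemm_matrices(demo_tree)
X_new = rng.standard_normal((200, 5)).astype(np.float32)
print(np.abs(gemm_predict(X_new, *mats) - demo_tree.predict(X_new)).max())
```

The inputs are drawn in float32 on purpose: scikit-learn casts `X` to float32 before comparing it with the float64 thresholds, so rounding the inputs the same way makes both computations follow the same path. With float64 inputs lying very close to a threshold, the two could disagree, which is the issue discussed throughout this notebook.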
The object returned by `convert` has an attribute `model`."]}, {"cell_type": "code", "execution_count": 39, "id": "e77ff4f0", "metadata": {}, "outputs": [{"data": {"text/plain": ["True"]}, "execution_count": 40, "metadata": {}, "output_type": "execute_result"}], "source": ["from torch.nn import Module\n", "isinstance(model.model, Module)"]}, {"cell_type": "markdown", "id": "871277df", "metadata": {}, "source": ["This attribute is a PyTorch module, so it can be exported to ONNX."]}, {"cell_type": "code", "execution_count": 40, "id": "3c875b35", "metadata": {}, "outputs": [], "source": ["import torch.onnx\n", "\n", "x = torch.randn(x_exp.shape[0], x_exp.shape[1], requires_grad=True)\n", "torch.onnx.export(model.model, x, 'tree_torch.onnx', opset_version=15, \n", " input_names=['X'], output_names=['variable'],\n", " dynamic_axes={\n", " 'X' : {0 : 'batch_size'},\n", " 'variable' : {0 : 'batch_size'}})"]}, {"cell_type": "code", "execution_count": 41, "id": "b8c41c5e", "metadata": {}, "outputs": [], "source": ["import onnx\n", "\n", "onxh = onnx.load('tree_torch.onnx')"]}, {"cell_type": "code", "execution_count": 42, "id": "861a94d0", "metadata": {"scrolled": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["opset: domain='' version=15\n", "input: name='X' type=dtype('float32') shape=['batch_size', 10]\n", "init: name='_operators.0.root_nodes' type=dtype('int64') shape=(0,) -- array([8], dtype=int64)\n", "init: name='_operators.0.root_biases' type=dtype('float32') shape=(0,) -- array([0.00792999], dtype=float32)\n", "init: name='_operators.0.tree_indices' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)\n", "init: name='_operators.0.leaf_nodes' type=dtype('float32') shape=(0,) -- array([ 1.0096807 , 0.6169403 , 0.61055773, 0.37810475, 0.31796893,\n", " 0.13317925, 0.0193846 , -0.2317742 , 0.39089343, 0.23506087,\n", " 0.3711936 , 0.10317916, 0.14956598, -0.14193445, -0.05965868,\n", " -0.27377078, 0.4128183 , 0.19658326, 0.25545415, 0.08118545,\n", " 0.08400188, -0.1502193 , 
-0.36846825, -0.79687625, 0.35822242,\n", " 0.49021915, 0.30870998, 0.01033915, 0.6740977 , 0.6740977 ,\n", " -0.15315758, -0.41128033, 0.42920846, 0.13145493, 0.21853392,\n", " -0.10986731, 0.4493652 , 0.11318789, 0.12666471, -0.0623082 ,\n", " 0.2872893 , 0.09948976, 0.11439473, -0.08801427, 0.16091613,\n", " -0.02319027, -0.10097775, -0.37583745, 0.18612385, -0.00453244,\n", " 0.3287116 , -0.1499349 , 0.7919218 , 0.04704398, -0.15423109,\n", " -0.43160027, 0.10802375, -0.1073833 , -0.07759219, -0.29175794,\n", " -0.1528881 , -0.4909434 , -0.23361537, -0.43578717, 0.7831867 ,\n", " 0.45349318, 0.34956965, -0.3199535 , 0.3061573 , -0.34267113,\n", " 0.34963542, 0.04491445, 0.35399815, 0.14815213, 0.06678926,\n", " -0.16095412, 0.3214274 , 0.01484008, -0.1012276 , -0.3257699 ,\n", " 0.26727676, 0.01970094, 0.10760042, -0.09169976, 0.20044112,\n", " -0.0324069 , -0.11015374, -0.28358367, 0.8083656 , 0.13358633,\n", " -0.07912118, -0.27182895, -0.07054728, -0.24895027, -0.20600456,\n", " -0.42033467, 0.34701794, -0.0638995 , 0.14252576, -0.06025055,\n", " 0.4228329 , 0.06789401, 0.03919645, -0.17267554, 0.07274943,\n", " -0.487512 , 0.04517636, -0.18857062, -0.03975222, -0.2652712 ,\n", " -0.30853328, -0.50844556, 0.03321444, -0.15481217, -0.20701212,\n", " -0.40578464, -0.25884995, -0.46550158, -0.4797585 , -0.7324234 ,\n", " 0.43939307, -0.06170902, -0.51546025, -0.19215119, -0.3705445 ,\n", " -0.57504356, -0.6372961 , -0.9345571 ], dtype=float32)\n", "init: name='_operators.0.nodes.0' type=dtype('int64') shape=(0,) -- array([0, 3], dtype=int64)\n", "init: name='_operators.0.nodes.1' type=dtype('int64') shape=(0,) -- array([1, 2, 5, 9], dtype=int64)\n", "init: name='_operators.0.nodes.2' type=dtype('int64') shape=(0,) -- array([5, 6, 3, 7, 2, 0, 7, 1], dtype=int64)\n", "init: name='_operators.0.nodes.3' type=dtype('int64') shape=(0,) -- array([3, 9, 5, 3, 6, 4, 1, 3, 6, 6, 1, 6, 5, 4, 6, 2], dtype=int64)\n", "init: name='_operators.0.nodes.4' type=dtype('int64') 
shape=(0,) -- array([3, 2, 7, 6, 2, 4, 7, 8, 9, 5, 7, 8, 9, 4, 6, 9, 7, 9, 0, 7, 7, 9,\n", " 2, 7, 6, 4, 6, 5, 4, 0, 6, 0], dtype=int64)\n", "init: name='_operators.0.nodes.5' type=dtype('int64') shape=(0,) -- array([2, 8, 7, 6, 6, 3, 4, 9, 7, 3, 2, 6, 3, 3, 0, 1, 1, 0, 4, 7, 9, 5,\n", " 7, 9, 5, 3, 5, 9, 0, 5, 1, 4, 9, 4, 7, 7, 1, 9, 1, 1, 6, 2, 7, 7,\n", " 6, 1, 4, 4, 0, 0, 9, 8, 8, 2, 6, 2, 0, 3, 4, 2, 5, 6, 7, 3],\n", " dtype=int64)\n", "init: name='_operators.0.biases.0' type=dtype('float32') shape=(0,) -- array([ 0.19169255, -0.12246682], dtype=float32)\n", "init: name='_operators.0.biases.1' type=dtype('float32') shape=(0,) -- array([-0.40610337, -0.1467492 , -0.01880287, 0.15879431], dtype=float32)\n", "init: name='_operators.0.biases.2' type=dtype('float32') shape=(0,) -- array([ 0.736786 , -0.32427853, 0.30860555, 0.17994082, 0.6917758 ,\n", " -0.00594712, 0.35950053, -0.9819274 ], dtype=float32)\n", "init: name='_operators.0.biases.3' type=dtype('float32') shape=(0,) -- array([-1.3495584 , -1.082793 , -0.6906011 , -0.08978076, -0.4007622 ,\n", " 0.10756078, -0.68507075, 0.15814054, 0.5132364 , -0.18426335,\n", " 0.13685235, 0.10721841, 0.01814443, -0.41644228, -0.59770894,\n", " 0.607365 ], dtype=float32)\n", "init: name='_operators.0.biases.4' type=dtype('float32') shape=(0,) -- array([ 1.4203796 , -0.49269757, -0.12210988, -0.09692484, 0.5076643 ,\n", " -1.3609421 , 1.154743 , 2.8748922 , -0.08181615, 0.7741028 ,\n", " 0.20604724, 0.666296 , -0.6474025 , 0.6459148 , 0.02262808,\n", " -0.42282397, 0.46360654, -0.10058792, 0.25486696, 0.60041225,\n", " -0.06933744, 0.21294908, 0.96443814, 0.07923891, 0.4797698 ,\n", " 1.2852331 , 0.24348404, -0.3404966 , -0.07175394, -0.8248828 ,\n", " -0.74071133, -1.2140133 ], dtype=float32)\n", "init: name='_operators.0.biases.5' type=dtype('float32') shape=(0,) -- array([ 1.0626682 , 1.4745288 , 0.01898679, 0.5451088 , 0.15444604,\n", " 1.0631477 , -0.7555804 , -1.7192128 , -0.20905146, 0.19752283,\n", " -0.40471953, 
0.13069782, 0.60331047, 1.5060809 , 0. ,\n", " -1.8283446 , -0.8124372 , -1.381897 , 0.59209645, 0.3239226 ,\n", " -0.42840806, -0.43624896, 0.58229303, -1.0196047 , -0.5632828 ,\n", " 0.91483426, 1.8038778 , -0.5665638 , -1.2530733 , -0.6500004 ,\n", " -1.3069727 , 0.48267984, 0.73503745, -1.871724 , -1.4965518 ,\n", " 1.3147466 , 0.03919952, -0.885836 , 0.5479692 , -0.8086383 ,\n", " -0.74240863, 0.14582941, 0.6496967 , -0.00911551, 2.4541488 ,\n", " -0.90482277, 0.26108736, 0.7569448 , -1.0786855 , -0.45229852,\n", " 1.2146595 , -0.6756766 , -2.3066258 , 0.7911504 , 0.57490873,\n", " -0.40741247, 0.24633038, -1.2022957 , -0.65162694, -0.04244827,\n", " 1.558136 , -1.6220782 , 0.1574643 , -1.4209061 ], dtype=float32)\n", "Constant(value=[-1]) -> onnx::Reshape_27\n", "Gather(X, _operators.0.root_nodes, axis=1) -> onnx::LessOrEqual_17\n", " LessOrEqual(onnx::LessOrEqual_17, _operators.0.root_biases) -> onnx::Cast_18\n", " Cast(onnx::Cast_18, to=7) -> onnx::Add_19\n", " Add(onnx::Add_19, _operators.0.tree_indices) -> onnx::Reshape_20\n", "Constant(value=[-1]) -> onnx::Reshape_21\n", " Reshape(onnx::Reshape_20, onnx::Reshape_21, allowzero=0) -> onnx::Gather_22\n", " Gather(_operators.0.nodes.0, onnx::Gather_22, axis=0) -> onnx::Reshape_23\n", "Constant(value=[-1, 1]) -> onnx::Reshape_24\n", " Reshape(onnx::Reshape_23, onnx::Reshape_24, allowzero=0) -> onnx::GatherElements_25\n", " GatherElements(X, onnx::GatherElements_25, axis=1) -> onnx::Reshape_26\n", " Reshape(onnx::Reshape_26, onnx::Reshape_27, allowzero=0) -> onnx::LessOrEqual_28\n", "Constant(value=2) -> onnx::Mul_29\n", " Mul(onnx::Gather_22, onnx::Mul_29) -> onnx::Add_30\n", "Gather(_operators.0.biases.0, onnx::Gather_22, axis=0) -> onnx::LessOrEqual_31\n", " LessOrEqual(onnx::LessOrEqual_28, onnx::LessOrEqual_31) -> onnx::Cast_32\n", " Cast(onnx::Cast_32, to=7) -> onnx::Add_33\n", " Add(onnx::Add_30, onnx::Add_33) -> onnx::Gather_34\n", " Gather(_operators.0.nodes.1, onnx::Gather_34, axis=0) -> 
onnx::Reshape_35\n", "Constant(value=[-1, 1]) -> onnx::Reshape_36\n", " Reshape(onnx::Reshape_35, onnx::Reshape_36, allowzero=0) -> onnx::GatherElements_37\n", " GatherElements(X, onnx::GatherElements_37, axis=1) -> onnx::Reshape_38\n", "Constant(value=[-1]) -> onnx::Reshape_39\n", " Reshape(onnx::Reshape_38, onnx::Reshape_39, allowzero=0) -> onnx::LessOrEqual_40\n", "Constant(value=2) -> onnx::Mul_41\n", " Mul(onnx::Gather_34, onnx::Mul_41) -> onnx::Add_42\n", "Gather(_operators.0.biases.1, onnx::Gather_34, axis=0) -> onnx::LessOrEqual_43\n", " LessOrEqual(onnx::LessOrEqual_40, onnx::LessOrEqual_43) -> onnx::Cast_44\n", " Cast(onnx::Cast_44, to=7) -> onnx::Add_45\n", " Add(onnx::Add_42, onnx::Add_45) -> onnx::Gather_46\n", " Gather(_operators.0.nodes.2, onnx::Gather_46, axis=0) -> onnx::Reshape_47\n", "Constant(value=[-1, 1]) -> onnx::Reshape_48\n", " Reshape(onnx::Reshape_47, onnx::Reshape_48, allowzero=0) -> onnx::GatherElements_49\n", " GatherElements(X, onnx::GatherElements_49, axis=1) -> onnx::Reshape_50\n", "Constant(value=[-1]) -> onnx::Reshape_51\n", " Reshape(onnx::Reshape_50, onnx::Reshape_51, allowzero=0) -> onnx::LessOrEqual_52\n", "Constant(value=2) -> onnx::Mul_53\n", " Mul(onnx::Gather_46, onnx::Mul_53) -> onnx::Add_54\n", "Gather(_operators.0.biases.2, onnx::Gather_46, axis=0) -> onnx::LessOrEqual_55\n", " LessOrEqual(onnx::LessOrEqual_52, onnx::LessOrEqual_55) -> onnx::Cast_56\n", " Cast(onnx::Cast_56, to=7) -> onnx::Add_57\n", " Add(onnx::Add_54, onnx::Add_57) -> onnx::Gather_58\n", " Gather(_operators.0.nodes.3, onnx::Gather_58, axis=0) -> onnx::Reshape_59\n", "Constant(value=[-1, 1]) -> onnx::Reshape_60\n", " Reshape(onnx::Reshape_59, onnx::Reshape_60, allowzero=0) -> onnx::GatherElements_61\n", " GatherElements(X, onnx::GatherElements_61, axis=1) -> onnx::Reshape_62\n", "Constant(value=[-1]) -> onnx::Reshape_63\n", " Reshape(onnx::Reshape_62, onnx::Reshape_63, allowzero=0) -> onnx::LessOrEqual_64\n", "Constant(value=2) -> onnx::Mul_65\n", " 
Mul(onnx::Gather_58, onnx::Mul_65) -> onnx::Add_66\n", "Gather(_operators.0.biases.3, onnx::Gather_58, axis=0) -> onnx::LessOrEqual_67\n", " LessOrEqual(onnx::LessOrEqual_64, onnx::LessOrEqual_67) -> onnx::Cast_68\n", " Cast(onnx::Cast_68, to=7) -> onnx::Add_69\n", " Add(onnx::Add_66, onnx::Add_69) -> onnx::Gather_70\n", " Gather(_operators.0.nodes.4, onnx::Gather_70, axis=0) -> onnx::Reshape_71\n", "Constant(value=[-1, 1]) -> onnx::Reshape_72\n", " Reshape(onnx::Reshape_71, onnx::Reshape_72, allowzero=0) -> onnx::GatherElements_73\n", " GatherElements(X, onnx::GatherElements_73, axis=1) -> onnx::Reshape_74\n", "Constant(value=[-1]) -> onnx::Reshape_75\n", " Reshape(onnx::Reshape_74, onnx::Reshape_75, allowzero=0) -> onnx::LessOrEqual_76\n", "Constant(value=2) -> onnx::Mul_77\n", " Mul(onnx::Gather_70, onnx::Mul_77) -> onnx::Add_78\n", "Gather(_operators.0.biases.4, onnx::Gather_70, axis=0) -> onnx::LessOrEqual_79\n", " LessOrEqual(onnx::LessOrEqual_76, onnx::LessOrEqual_79) -> onnx::Cast_80\n", " Cast(onnx::Cast_80, to=7) -> onnx::Add_81\n", " Add(onnx::Add_78, onnx::Add_81) -> onnx::Gather_82\n", " Gather(_operators.0.nodes.5, onnx::Gather_82, axis=0) -> onnx::Reshape_83\n", "Constant(value=[-1, 1]) -> onnx::Reshape_84\n", " Reshape(onnx::Reshape_83, onnx::Reshape_84, allowzero=0) -> onnx::GatherElements_85\n", " GatherElements(X, onnx::GatherElements_85, axis=1) -> onnx::Reshape_86\n", "Constant(value=[-1]) -> onnx::Reshape_87\n", " Reshape(onnx::Reshape_86, onnx::Reshape_87, allowzero=0) -> onnx::LessOrEqual_88\n", "Constant(value=2) -> onnx::Mul_89\n", " Mul(onnx::Gather_82, onnx::Mul_89) -> onnx::Add_90\n", "Gather(_operators.0.biases.5, onnx::Gather_82, axis=0) -> onnx::LessOrEqual_91\n", " LessOrEqual(onnx::LessOrEqual_88, onnx::LessOrEqual_91) -> onnx::Cast_92\n", " Cast(onnx::Cast_92, to=7) -> onnx::Add_93\n", " Add(onnx::Add_90, onnx::Add_93) -> onnx::Gather_94\n", " Gather(_operators.0.leaf_nodes, onnx::Gather_94, axis=0) -> onnx::Reshape_95\n", 
"Constant(value=[-1, 1, 1]) -> onnx::Reshape_96\n", " Reshape(onnx::Reshape_95, onnx::Reshape_96, allowzero=0) -> output\n", "Constant(value=[1]) -> onnx::ReduceSum_98\n", " ReduceSum(output, onnx::ReduceSum_98, keepdims=0) -> variable\n"]}, {"name": "stdout", "output_type": "stream", "text": ["output: name='variable' type=dtype('float32') shape=['batch_size', 'ReduceSumvariable_dim_1']\n"]}], "source": ["print(onnx_simple_text_plot(onxh, raise_exc=False))"]}, {"cell_type": "code", "execution_count": 43, "id": "822bfa80", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 44, "metadata": {}, "output_type": "execute_result"}], "source": ["%onnxview onxh"]}, {"cell_type": "markdown", "id": "1edb6177", "metadata": {}, "source": ["The library reimplements the decision path of the tree level by level: at each level, the graph above retrieves (`Gather`, `GatherElements`) the feature tested by the current node, compares it with the node threshold (`LessOrEqual`) and computes the index of the next node. Only the features that are needed are retrieved, so no sparse matrix is required. The decision threshold is implemented with a test and not with a sigmoid. In terms of predictions, this model should therefore match the initial one, up to the float32 rounding of inputs and thresholds."]}, {"cell_type": "code", "execution_count": 44, "id": "2220ca2e", "metadata": {}, "outputs": [{"data": {"text/plain": ["1.7421041873949668"]}, "execution_count": 45, "metadata": {}, "output_type": "execute_result"}], "source": ["oinfh = OnnxInference(onxh, runtime='onnxruntime1')\n", "expected = tree.predict(x_exp)\n", "\n", "got = oinfh.run({'X': x_exp.astype(numpy.float32)})['variable']\n", "numpy.abs(got - expected).max()"]}, {"cell_type": "markdown", "id": "10de2a80", "metadata": {}, "source": ["The conversion seems imperfect as well. This figure is likely misleading, though: according to the graph above, the output `variable` has shape `(batch_size, 1)` whereas `tree.predict` returns a vector of shape `(batch_size,)`, so `got - expected` is broadcast into a square matrix that compares every prediction with every expected value. The output should be flattened (`got.ravel()`) before measuring the discrepancy."]}, {"cell_type": "code", "execution_count": 45, "id": "fd13b28b", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["3.13 ms \u00b1 445 \u00b5s per loop (mean \u00b1 std. dev. 
of 7 runs, 100 loops each)\n"]}], "source": ["%timeit oinfh.run({'X': x_exp32})['variable']"]}, {"cell_type": "markdown", "id": "11a36a32", "metadata": {}, "source": ["The computation time is also longer: 3.13 ms with onnxruntime versus 1.17 ms for the PyTorch model produced by hummingbird."]}, {"cell_type": "markdown", "id": "20afcc41", "metadata": {}, "source": ["## Training\n", "\n", "The idea behind all this is also to be able to re-estimate the coefficients of the neural network once the tree has been converted."]}, {"cell_type": "code", "execution_count": 46, "id": "96abfddb", "metadata": {}, "outputs": [], "source": ["x_train = X_train[:100]\n", "expected = tree.predict(x_train)\n", "reg = NeuralTreeNetRegressor(trees[1], verbose=1, max_iter=10, lr=1e-4)"]}, {"cell_type": "code", "execution_count": 47, "id": "94dc4d66", "metadata": {}, "outputs": [{"data": {"text/plain": ["(1.0246115055833722, 0.24094382754240642)"]}, "execution_count": 48, "metadata": {}, "output_type": "execute_result"}], "source": ["got = reg.predict(x_train)\n", "numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()"]}, {"cell_type": "markdown", "id": "111970a1", "metadata": {}, "source": ["The difference is large: about 1.02 at most and 0.24 on average."]}, {"cell_type": "code", "execution_count": 48, "id": "a50b3384", "metadata": {"scrolled": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["0/10: loss: 3.201 lr=0.0001 max(coef): 6.5 l1=0/1.5e+03 l2=0/2.5e+03\n", "1/10: loss: 2.593 lr=9.95e-06 max(coef): 6.5 l1=2e+03/1.5e+03 l2=1.3e+03/2.5e+03\n", "2/10: loss: 2.506 lr=7.05e-06 max(coef): 6.5 l1=1.4e+02/1.5e+03 l2=6.2/2.5e+03\n", "3/10: loss: 2.461 lr=5.76e-06 max(coef): 6.5 l1=1.2e+03/1.5e+03 l2=6.8e+02/2.5e+03\n", "4/10: loss: 2.429 lr=4.99e-06 max(coef): 6.5 l1=6.5e+02/1.5e+03 l2=2.1e+02/2.5e+03\n", "5/10: loss: 2.405 lr=4.47e-06 max(coef): 6.5 l1=1.9e+02/1.5e+03 l2=13/2.5e+03\n", "6/10: loss: 2.392 lr=4.08e-06 max(coef): 6.5 l1=1.6e+02/1.5e+03 l2=6.8/2.5e+03\n", "7/10: loss: 2.375 lr=3.78e-06 max(coef): 6.5 l1=1.8e+02/1.5e+03 l2=9.5/2.5e+03\n", "8/10: 
loss: 2.358 lr=3.53e-06 max(coef): 6.5 l1=1.1e+02/1.5e+03 l2=7/2.5e+03\n", "9/10: loss: 2.345 lr=3.33e-06 max(coef): 6.5 l1=3.7e+02/1.5e+03 l2=56/2.5e+03\n", "10/10: loss: 2.333 lr=3.16e-06 max(coef): 6.5 l1=6.1e+02/1.5e+03 l2=1.3e+02/2.5e+03\n"]}, {"data": {"text/html": ["
"], "text/plain": ["NeuralTreeNetRegressor(estimator=None, lr=0.0001, max_iter=10, verbose=1)"]}, "execution_count": 49, "metadata": {}, "output_type": "execute_result"}], "source": ["reg.fit(x_train, expected)"]}, {"cell_type": "code", "execution_count": 49, "id": "c3ae49b2", "metadata": {}, "outputs": [{"data": {"text/plain": ["(1.256860512819292, 0.25663312220721907)"]}, "execution_count": 50, "metadata": {}, "output_type": "execute_result"}], "source": ["got = reg.predict(x_train)\n", "numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()"]}, {"cell_type": "markdown", "id": "831e538f", "metadata": {}, "source": ["It does not work as well as expected: the training loss decreases from 3.201 to 2.333, but the gap with the tree does not shrink (maximum 1.26 instead of 1.02, mean 0.26 instead of 0.24). More iterations and some tuning of the learning parameters would probably be needed."]}], "metadata": {"kernelspec": {"display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.5"}}, "nbformat": 4, "nbformat_minor": 5}