{"cells": [{"cell_type": "markdown", "id": "4b855db5", "metadata": {}, "source": ["# NeuralTreeNet and ONNX\n", "\n", "Converting a decision tree to ONNX can introduce discrepancies between the original model and the converted one (see [Issues when switching to float](http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/gyexamples/plot_ebegin_float_double.html)). The problem comes from a type change: the decision thresholds are rounded to the float32 value closest to their float64 (double) value. What happens if the decision tree is first converted into a neural network?\n", "\n", "In most cases, approximating the decision thresholds changes very little. However, comparing a feature with a rounded threshold can give the opposite result from comparing it with the original threshold. When that happens, the sample follows a different path through the tree."]}, {"cell_type": "code", "execution_count": 1, "id": "636a122a", "metadata": {}, "outputs": [{"data": {"text/html": ["
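The rounding effect described above can be reproduced without ONNX. This is a minimal sketch in pure Python: `struct.pack("f", ...)` stores a double in 4 bytes, which rounds it to the nearest float32. The threshold `0.1` and the feature value are hypothetical, and the comparison uses `<=`, the test scikit-learn decision trees apply to send a sample to the left child. The last line also casts the feature itself to float32, as a float32 input tensor would be.

```python
import struct


def to_float32(value: float) -> float:
    """Round a float64 value to the nearest float32 and return it as a Python float."""
    return struct.unpack("f", struct.pack("f", value))[0]


threshold = 0.1                      # hypothetical threshold learned in float64
threshold32 = to_float32(threshold)  # what a float32 model would store

# Hypothetical feature value lying strictly between the two thresholds.
x = (threshold + threshold32) / 2

print(threshold32)                   # 0.10000000149011612
print(x <= threshold)                # False: right branch with the float64 threshold
print(x <= threshold32)              # True: left branch with the rounded threshold
print(to_float32(x) <= threshold32)  # True: still left once x is cast to float32
```

Only samples whose value falls between the float64 threshold and its float32 rounding change branch, which is why the discrepancy stays rare.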
\n", ""], "text/plain": ["

| | k | max | mean |
|---|---|---|---|
| 0 | 0.3 | 0.568981 | 0.158758 |
| 1 | 0.4 | 0.608304 | 0.132576 |
| 2 | 0.5 | 0.692657 | 0.128525 |
| 3 | 0.7 | 0.780543 | 0.131497 |
| 4 | 0.9 | 0.809866 | 0.128368 |
| 5 | 1.0 | 0.813889 | 0.124802 |
| 6 | 5.0 | 0.392482 | 0.022466 |
| 7 | 10.0 | 0.341749 | 0.006350 |
| 8 | 15.0 | 0.270649 | 0.002939 |
| 9 | 20.0 | 0.299713 | 0.002110 |
| 10 | 25.0 | 0.305493 | 0.001842 |
| 11 | 30.0 | 0.306111 | 0.001767 |
| 12 | 35.0 | 0.299371 | 0.001665 |
| 13 | 40.0 | 0.233556 | 0.001011 |
| 14 | 45.0 | 0.233606 | 0.000801 |
| 15 | 50.0 | 0.233614 | 0.000547 |
| 16 | 55.0 | 0.233615 | 0.000499 |
| 17 | 60.0 | 0.233615 | 0.000484 |
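The table above reports, for each value of `k`, a `max` and a `mean` difference. This section does not show how they are computed. The following minimal sketch assumes that `k` is the slope of a sigmoid replacing each hard test `x > t` (the profile below does contain `Sigmoid` nodes), and shows the trend one would expect on a single hypothetical threshold `t = 0.5` with uniform samples. The function names are illustrative, not the NeuralTreeNet API.

```python
import math
import random


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def hard_test(x: float, threshold: float) -> float:
    """Decision of a tree node: 1.0 when x > threshold, else 0.0."""
    return 1.0 if x > threshold else 0.0


def soft_test(x: float, threshold: float, k: float) -> float:
    """Hypothetical smooth surrogate: a larger slope k sharpens the transition."""
    return sigmoid(k * (x - threshold))


random.seed(0)
threshold = 0.5
samples = [random.random() for _ in range(5000)]

results = {}
for k in (0.3, 1.0, 5.0, 20.0, 60.0):
    errors = [abs(soft_test(x, threshold, k) - hard_test(x, threshold)) for x in samples]
    results[k] = (max(errors), sum(errors) / len(errors))
    print(f"k={k:5.1f}  max={results[k][0]:.6f}  mean={results[k][1]:.6f}")
```

Every per-sample error shrinks as `k` grows, so the mean keeps decreasing. A sample sitting right next to the threshold, however, keeps an error close to 0.5 because the sigmoid equals 0.5 at `x = t`. This may be why the `max` column levels off instead of reaching zero.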

| | cat | pid | tid | dur | ts | ph | name | args_op_name | args_parameter_size | args_graph_index | args_provider | args_exec_plan_index | args_activation_size | args_output_size | args_input_type_shape | args_output_type_shape | args_thread_scheduling_stats |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Session | 78116 | 8820 | 387 | 4 | X | model_loading_array | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 1 | Session | 78116 | 8820 | 2532 | 428 | X | session_initialization | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 2 | Node | 78116 | 8820 | 0 | 3294 | X | gemm_fence_before | Gemm | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 3 | Node | 78116 | 8820 | 1315 | 3300 | X | gemm_kernel_time | Gemm | 5544 | 11 | CPUExecutionProvider | 11 | 200000 | 2520000 | [{'float': [5000, 10]}, {'float': [10, 126]}, ... | [{'float': [5000, 126]}] | {'main_thread': {'thread_pool_name': 'session-... |
| 4 | Node | 78116 | 8820 | 0 | 4635 | X | gemm_fence_after | Gemm | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 986 | Node | 78116 | 8820 | 0 | 210170 | X | Ma_MatMul2_fence_before | MatMul | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 987 | Node | 78116 | 8820 | 124 | 210172 | X | Ma_MatMul2_kernel_time | MatMul | 508 | 8 | CPUExecutionProvider | 8 | 2540000 | 20000 | [{'float': [5000, 127]}, {'float': [127, 1]}] | [{'float': [5000, 1]}] | {'main_thread': {'thread_pool_name': 'session-... |
| 988 | Node | 78116 | 8820 | 0 | 210305 | X | Ma_MatMul2_fence_after | MatMul | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 989 | Session | 78116 | 8820 | 4378 | 205930 | X | SequentialExecutor::Execute | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 990 | Session | 78116 | 8820 | 4388 | 205925 | X | model_run | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |

991 rows × 17 columns
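The table above comes from an onnxruntime profiling trace. When `SessionOptions.enable_profiling` is set to `True`, the session records its steps, and `InferenceSession.end_profiling()` returns the path of a JSON file holding a list of events. Each event carries the keys shown as columns (`cat`, `pid`, `tid`, `dur`, `ts`, `ph`, `name`) plus an `args` dictionary. This minimal sketch turns such a list into flat rows with the `args_` prefix seen above. The two events form a trimmed, hypothetical excerpt built from rows 0 and 3 of the table, and the notebook itself may have used pandas for this step.

```python
import json


def flatten_events(events):
    """Copy each event's top-level keys and lift its `args` entries as `args_<key>` columns."""
    rows = []
    for event in events:
        row = {key: value for key, value in event.items() if key != "args"}
        for key, value in event.get("args", {}).items():
            # Nested values (shapes, thread stats) stay as lists/dicts, as in the table.
            row["args_" + key] = value
        rows.append(row)
    return rows


# Hypothetical, trimmed excerpt; a real trace is read from the profiling file.
trace = json.loads("""[
  {"cat": "Session", "pid": 78116, "tid": 8820, "dur": 387, "ts": 4, "ph": "X",
   "name": "model_loading_array", "args": {}},
  {"cat": "Node", "pid": 78116, "tid": 8820, "dur": 1315, "ts": 3300, "ph": "X",
   "name": "gemm_kernel_time",
   "args": {"op_name": "Gemm", "provider": "CPUExecutionProvider",
            "output_type_shape": [{"float": [5000, 126]}]}}
]""")

rows = flatten_events(trace)
for row in rows:
    print(row["name"], row.get("args_op_name"))
```

The session event has no `op_name`, so its row simply lacks that column. A DataFrame built from these rows shows such gaps as NaN, as in the table.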

| args_op_name | name | dur |
|---|---|---|
| MatMul | Ma_MatMul2 | 6778 |
| Mul | Mu_Mul | 12923 |
| Sigmoid | Si_Sigmoid | 14849 |
| Mul | Mu_Mul1 | 15151 |
| Sigmoid | Si_Sigmoid1 | 15608 |
| Gemm | gemm | 31763 |
| Gemm | gemm_token_0 | 99047 |
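The table above sums `dur` per operator and node. Each node appears 43 times in the next table rather than 3 × 43, so the grouping seems to keep only the `*_kernel_time` events, with that suffix stripped from the node name. This is a minimal standard-library sketch of such an aggregation, sorted like the table. The input rows are hypothetical, and the filtering rule is an assumption inferred from the tables, not taken from the notebook code.

```python
from collections import defaultdict

SUFFIX = "_kernel_time"


def kernel_time_by_node(rows):
    """Sum and count `dur` per (op_name, node name) over *_kernel_time events."""
    totals = defaultdict(int)
    counts = defaultdict(int)
    for row in rows:
        name = row.get("name", "")
        if row.get("cat") != "Node" or not name.endswith(SUFFIX):
            continue  # skip session events and *_fence_before / *_fence_after
        key = (row["args_op_name"], name[: -len(SUFFIX)])
        totals[key] += row["dur"]
        counts[key] += 1
    # Smallest total first, as in the table.
    ordered = sorted(totals.items(), key=lambda item: item[1])
    return ordered, dict(counts)


# Hypothetical flattened rows covering two runs of two nodes.
rows = [
    {"cat": "Session", "name": "model_run", "dur": 4388},
    {"cat": "Node", "name": "gemm_fence_before", "args_op_name": "Gemm", "dur": 0},
    {"cat": "Node", "name": "gemm_kernel_time", "args_op_name": "Gemm", "dur": 1315},
    {"cat": "Node", "name": "gemm_kernel_time", "args_op_name": "Gemm", "dur": 700},
    {"cat": "Node", "name": "Ma_MatMul2_kernel_time", "args_op_name": "MatMul", "dur": 124},
    {"cat": "Node", "name": "Ma_MatMul2_kernel_time", "args_op_name": "MatMul", "dur": 150},
]

ordered, counts = kernel_time_by_node(rows)
for (op_name, node), dur in ordered:
    print(op_name, node, dur, counts[(op_name, node)])
```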

| args_op_name | name | dur |
|---|---|---|
| MatMul | Ma_MatMul2 | 43 |
| Mul | Mu_Mul | 43 |
| Sigmoid | Si_Sigmoid | 43 |
| Mul | Mu_Mul1 | 43 |
| Sigmoid | Si_Sigmoid1 | 43 |
| Gemm | gemm | 43 |
| Gemm | gemm_token_0 | 43 |
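Every node is counted 43 times. Suppose each run emits three events per node (`*_fence_before`, `*_kernel_time`, `*_fence_after`) plus `SequentialExecutor::Execute` and `model_run`, and the session logs `model_loading_array` and `session_initialization` once. Under that assumption, read off the event names in the profiling table, 43 runs account exactly for its 991 rows:

```python
# Event names come from the profiling table; the per-run structure is an assumption.
nodes = 7                   # distinct (op_name, name) pairs in the duration tables
events_per_node = 3         # *_fence_before, *_kernel_time, *_fence_after
session_events_per_run = 2  # SequentialExecutor::Execute, model_run
startup_events = 2          # model_loading_array, session_initialization
runs = 43                   # count reported for every node

total_rows = startup_events + runs * (nodes * events_per_node + session_events_per_run)
print(total_rows)  # 991
```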

| | 127 | 12 |
|---|---|---|
| cat | Node | Node |
| pid | 78116 | 78116 |
| tid | 8820 | 8820 |
| dur | 4603 | 4083 |
| ts | 37173 | 5949 |
| ph | X | X |
| name | gemm_token_0_kernel_time | gemm_token_0_kernel_time |
| args_op_name | Gemm | Gemm |
| args_parameter_size | 64516 | 64516 |
| args_graph_index | 12 | 12 |
| args_provider | CPUExecutionProvider | CPUExecutionProvider |
| args_exec_plan_index | 12 | 12 |
| args_activation_size | 2520000 | 2520000 |
| args_output_size | 2540000 | 2540000 |
| args_input_type_shape | [{'float': [5000, 126]}, {'float': [126, 127]}... | [{'float': [5000, 126]}, {'float': [126, 127]}... |
| args_output_type_shape | [{'float': [5000, 127]}] | [{'float': [5000, 127]}] |
| args_thread_scheduling_stats | {'main_thread': {'thread_pool_name': 'session-... | {'main_thread': {'thread_pool_name': 'session-... |
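The size fields of these two `gemm_token_0` events match byte counts of float32 tensors at 4 bytes per value. The activation is the `[5000, 126]` input, the output is `[5000, 127]`, and the parameter size matches a `[126, 127]` weight matrix plus 127 extra values. The third Gemm input is elided (`...`) in the table, so reading those 127 values as a bias vector is an assumption. A quick check:

```python
FLOAT32_BYTES = 4  # the shapes above are reported as 'float' tensors


def tensor_bytes(*shape: int) -> int:
    """Size in bytes of a float32 tensor with the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count * FLOAT32_BYTES


activation = tensor_bytes(5000, 126)                     # input X
output = tensor_bytes(5000, 127)                         # output Y
parameters = tensor_bytes(126, 127) + tensor_bytes(127)  # weights + assumed bias
print(activation, output, parameters)  # 2520000 2540000 64516
```

The same reasoning gives the 5544 bytes of the first `gemm` node (`[10, 126]` weights plus 126 values) and the 508 bytes of `Ma_MatMul2` (`[127, 1]` weights, no extra vector). Hidden layers of 126 and 127 units would be consistent with a tree of 126 decision nodes and 127 leaves. Every internal node of a scikit-learn tree has two children, so such a tree always has one more leaf than decision nodes.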

| args_op_name | name | dur |
|---|---|---|
| MatMul | Ma_MatMul2 | 0.034561 |
| Mul | Mu_Mul | 0.065894 |
| Sigmoid | Si_Sigmoid | 0.075714 |
| Mul | Mu_Mul1 | 0.077254 |
| Sigmoid | Si_Sigmoid1 | 0.079584 |
| Gemm | gemm | 0.161958 |
| Gemm | gemm_token_0 | 0.505035 |
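This last table looks like each node's total kernel duration divided by the sum over all nodes. Recomputing it from the totals of the duration table above reproduces the values to six decimals:

```python
# Total kernel durations per node, copied from the duration table above.
totals = {
    ("MatMul", "Ma_MatMul2"): 6778,
    ("Mul", "Mu_Mul"): 12923,
    ("Sigmoid", "Si_Sigmoid"): 14849,
    ("Mul", "Mu_Mul1"): 15151,
    ("Sigmoid", "Si_Sigmoid1"): 15608,
    ("Gemm", "gemm"): 31763,
    ("Gemm", "gemm_token_0"): 99047,
}

grand_total = sum(totals.values())
shares = {key: dur / grand_total for key, dur in totals.items()}

for (op_name, node), share in sorted(shares.items(), key=lambda item: item[1]):
    print(f"{op_name:8s} {node:14s} {share:.6f}")
```

The two Gemm nodes together take about two thirds of the kernel time (0.161958 + 0.505035).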

NeuralTreeNetRegressor(estimator=None, lr=0.0001, max_iter=10, verbose=1)