SerializeToString and ParseFromString#

Startup#

This section imports the packages used by the benchmark; the ONNX graphs themselves are built in the next section.

import numpy
import onnx
from pyquickhelper.pycode.profiling import profile, profile2graph
from cpyquickhelper.numbers.speed_measure import measure_time
import matplotlib.pyplot as plt
import pandas
from tqdm import tqdm
from onnx.checker import check_model
from mlprodict.testing.experimental_c_impl.experimental_c import code_optimisation
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.npy.xop import loadop
try:
    from mlprodict.onnx_tools._onnx_check_model import check_model as check_model_py
except ImportError:
    check_model_py = None

Optimisations available on this machine.

# Display the optimisation settings reported by the experimental C module.
optimisation = code_optimisation()
print(optimisation)
AVX-omp=8

Build an ONNX graph of different size#

def build_model(n_nodes, size, opv=15):
    """Build an ONNX model made of a chain of ``Add`` nodes.

    Every ``Add`` node adds a random float32 vector of length *size* to the
    result of the previous node (the first one starts from input ``X``).
    A final ``Identity`` node names the output ``Y``.

    :param n_nodes: number of ``Add`` nodes in the chain
    :param size: length of the random vector added by each node
    :param opv: opset version given to every operator and used as
        ``target_opset`` for the conversion
    :return: the value returned by ``to_onnx`` for input ``X`` and
        output ``Y``
    """
    OnnxAdd, OnnxIdentity = loadop('Add', 'Identity')
    current = 'X'
    for _ in range(n_nodes):
        constant = numpy.random.randn(size).astype(numpy.float32)
        current = OnnxAdd(current, constant, op_version=opv)
    last = OnnxIdentity(current, op_version=opv, output_names=['Y'])
    # The same zero array is passed for both the input and the output
    # in the to_onnx call.
    sample = numpy.zeros((10, 10), dtype=numpy.float32)
    return last.to_onnx({'X': sample}, {'Y': sample}, target_opset=opv)


# Small example: two Add nodes, each adding a vector of five values.
model = build_model(n_nodes=2, size=5)
model_text = onnx_simple_text_plot(model)
print(model_text)
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=[10, 10]
init: name='init' type=dtype('float32') shape=(5,)
init: name='init_1' type=dtype('float32') shape=(5,)
Add(X, init) -> out_add_0
  Add(out_add_0, init_1) -> Y
output: name='Y' type=dtype('float32') shape=[10, 10]

Measure the time of serialization functions#

def parse(buffer):
    """Deserialize *buffer* into a new ``onnx.ModelProto``.

    :param buffer: serialized model, passed as is to ``ParseFromString``
    :return: the freshly created ``ModelProto`` filled from *buffer*
    """
    model_proto = onnx.ModelProto()
    model_proto.ParseFromString(buffer)
    return model_proto


# Benchmark loop: for every (size, n_nodes) pair, build a model and time
# each serialization-related task on it. Every measurement is stored in
# ``data`` together with the parameters of the model it was taken on.
data = []
nodes = [5, 10, 20]
for size in tqdm([10, 100, 1000, 10000, 100000, 200000, 300000]):
    # The number of repetitions only depends on ``size``: fewer repetitions
    # are used for the biggest models.
    repeat = 20 if size < 100000 else 5
    for n_nodes in nodes:
        onx = build_model(n_nodes, size)
        serialized = onx.SerializeToString()
        onnx_size = len(serialized)

        # (task name, callable to time), in the order the measurements are
        # appended to ``data``.
        tasks = [
            ("SerializeToString", lambda: onx.SerializeToString()),
            ("ParseFromString", lambda: parse(serialized)),
            ("check_model", lambda: check_model(onx, full_check=False)),
        ]
        # The python checker is optional (its import may have failed).
        if check_model_py is not None:
            tasks.append(("check_model_py", lambda: check_model_py(onx)))

        # The lambdas are consumed within the same iteration, so they always
        # see the current ``onx`` and ``serialized``.
        for task, fct in tasks:
            obs = measure_time(fct, div_by_number=True, repeat=repeat)
            obs.update(size=size, n_nodes=n_nodes,
                       onnx_size=onnx_size, task=task)
            data.append(obs)
  0%|          | 0/7 [00:00<?, ?it/s]
 14%|#4        | 1/7 [00:13<01:18, 13.14s/it]
 29%|##8       | 2/7 [00:26<01:05, 13.19s/it]
 43%|####2     | 3/7 [00:39<00:53, 13.36s/it]
 57%|#####7    | 4/7 [00:56<00:44, 14.73s/it]
 71%|#######1  | 5/7 [01:10<00:28, 14.33s/it]
 86%|########5 | 6/7 [01:33<00:17, 17.43s/it]
100%|##########| 7/7 [02:07<00:00, 22.63s/it]
100%|##########| 7/7 [02:07<00:00, 18.17s/it]

time

# One row per measurement, ordered by task and then by model parameters.
sort_keys = ['task', 'onnx_size', 'size', 'n_nodes']
df = pandas.DataFrame(data).sort_values(sort_keys)
# Show the sort keys next to the average time of each measurement.
df[sort_keys + ['average']]
task onnx_size size n_nodes average
1 ParseFromString 573 10 5 0.000071
5 ParseFromString 1108 10 10 0.000106
9 ParseFromString 2274 10 20 0.000171
13 ParseFromString 2383 100 5 0.000075
17 ParseFromString 4728 100 10 0.000108
... ... ... ... ... ...
67 check_model_py 8000770 200000 10 0.016459
59 check_model_py 8001596 100000 20 0.015613
79 check_model_py 12000770 300000 10 0.022685
71 check_model_py 16001596 200000 20 0.032270
83 check_model_py 24001596 300000 20 0.046195

84 rows × 5 columns



Summary#

# Save the raw measurements, then lay the average times out with one row
# per model size and one column per task.
df.to_excel("time.xlsx", index=False)
piv = df.pivot(
    index='onnx_size',
    columns='task',
    values='average',
)
piv
task ParseFromString SerializeToString check_model check_model_py
onnx_size
573 0.000071 0.000088 0.000144 0.001822
1108 0.000106 0.000134 0.000236 0.003213
2274 0.000171 0.000219 0.000393 0.006007
2383 0.000075 0.000091 0.000157 0.001828
4728 0.000108 0.000135 0.000240 0.003226
9514 0.000177 0.000221 0.000398 0.006025
20389 0.000081 0.000093 0.000166 0.001858
40739 0.000122 0.000140 0.000254 0.003278
81535 0.000205 0.000232 0.000430 0.006164
200399 0.000110 0.000115 0.000250 0.001993
400759 0.000268 0.000240 0.000643 0.003585
801575 0.000618 0.000569 0.001207 0.006630
2000404 0.000977 0.000903 0.001863 0.004334
4000405 0.001263 0.001242 0.002445 0.008527
4000770 0.001826 0.001700 0.003561 0.008047
6000405 0.001762 0.001706 0.003448 0.011656
8000770 0.002345 0.002294 0.004584 0.016459
8001596 0.003417 0.003161 0.006496 0.015613
12000770 0.003430 0.003452 0.006722 0.022685
16001596 0.004550 0.004430 0.009106 0.032270
24001596 0.006657 0.006511 0.012976 0.046195


Graph#

# Plot one curve per task: average time against the serialized model size.
fig, ax = plt.subplots(1, 1)
plot_title = "Time processing of serialization functions\nlower better"
piv.plot(title=plot_title, ax=ax)
ax.set_xlabel("onnx size")
ax.set_ylabel("s")
Time processing of serialization functions lower better
Text(33.972222222222214, 0.5, 's')

Conclusion#

This graph shows that the Python implementation of check_model is much slower than the C++ version. However, protobuf prevents sharing a ModelProto between Python and C++ (see Python Updates) unless the Python package is compiled with a specific setting (probably slower). Profiling shows that the code spends a significant amount of time in the built-in function getattr().

# Profile the python implementation of check_model on the last model built
# by the benchmark loop. ``check_model_py`` is None when its import failed,
# in which case there is nothing to profile.
if check_model_py is None:
    print("check_model_py is not available, profiling is skipped.")
else:
    ps = profile(lambda: check_model_py(onx))[0]
    # Only keep the last component of every path in the report.
    root, nodes = profile2graph(ps, clean_text=lambda x: x.split('/')[-1])
    text = root.to_text()
    print(text)

# plt.show()
__getattr__                                                  --  160  160 -- 0.00099 0.00280 -- _onnx_check_model.py:68:__getattr__ (__getattr__)
    <method 'endswith' of 'str' objects>                     --  160  160 -- 0.00020 0.00020 -- ~:0:<method 'endswith' of 'str' objects> (<method 'endswith' of 'str' objects>) +++
    <built-in method builtins.getattr>                       --  160  160 -- 0.00067 0.00067 -- ~:0:<built-in method builtins.getattr> (<built-in method builtins.getattr>) +++
    <built-in method builtins.hasattr>                       --  160  160 -- 0.00094 0.00094 -- ~:0:<built-in method builtins.hasattr> (<built-in method builtins.hasattr>) +++
__init__                                                     --    2    2 -- 0.00002 0.00002 -- _onnx_check_model.py:312:__init__ (__init__)
get_ir_version                                               --   21   21 -- 0.00004 0.00004 -- _onnx_check_model.py:326:get_ir_version (get_ir_version)
get_opset_imports                                            --   21   21 -- 0.00001 0.00001 -- _onnx_check_model.py:334:get_opset_imports (get_opset_imports)
set_opset_imports                                            --    2    2 -- 0.00000 0.00000 -- _onnx_check_model.py:338:set_opset_imports (set_opset_imports)
__init__                                                     --    2    3 -- 0.00001 0.00001 -- _onnx_check_model.py:377:__init__ (__init__)
    copy                                                     --    1    1 -- 0.00000 0.00001 -- _onnx_check_model.py:398:copy (copy)
        __init__                                             --    1    1 -- 0.00000 0.00000 -- _onnx_check_model.py:377:__init__ (__init__) +++
this_graph_has                                               --   81   81 -- 0.00008 0.00008 -- _onnx_check_model.py:388:this_graph_has (this_graph_has)
this_or_ancestor_graph_has                                   --   60   80 -- 0.00013 0.00021 -- _onnx_check_model.py:392:this_or_ancestor_graph_has (this_or_ancestor_graph_has)
    this_graph_has                                           --   80   80 -- 0.00008 0.00008 -- _onnx_check_model.py:388:this_graph_has (this_graph_has) +++
    this_or_ancestor_graph_has                               --   20   20 -- 0.00002 0.00003 -- _onnx_check_model.py:392:this_or_ancestor_graph_has (this_or_ancestor_graph_has) +++
_enforce_has_field                                           --   46   46 -- 0.00007 0.00024 -- _onnx_check_model.py:405:_enforce_has_field (_enforce_has_field)
    <built-in method builtins.hasattr>                       --   46   46 -- 0.00018 0.00018 -- ~:0:<built-in method builtins.hasattr> (<built-in method builtins.hasattr>) +++
_enforce_non_empty_field                                     --   23   23 -- 0.00003 0.00009 -- _onnx_check_model.py:417:_enforce_non_empty_field (_enforce_non_empty_field)
    <built-in method builtins.getattr>                       --   23   23 -- 0.00006 0.00006 -- ~:0:<built-in method builtins.getattr> (<built-in method builtins.getattr>) +++
<lambda>                                                     --    1    1 -- 0.00000 0.04924 -- plot_benchmark_onnx_serialize.py:150:<lambda> (<lambda>)
    check_model                                              --    1    1 -- 0.00002 0.04923 -- _onnx_check_model.py:1266:check_model (check_model)
        __init__                                             --    1    1 -- 0.00001 0.00001 -- _onnx_check_model.py:312:__init__ (__init__) +++
        _check_model                                         --    1    1 -- 0.00008 0.04920 -- _onnx_check_model.py:1223:_check_model (_check_model)
            get_ir_version                                   --    1    1 -- 0.00000 0.00000 -- _onnx_check_model.py:326:get_ir_version (get_ir_version) +++
            set_ir_version                                   --    1    1 -- 0.00000 0.00000 -- _onnx_check_model.py:330:set_ir_version (set_ir_version)
            set_opset_imports                                --    1    1 -- 0.00000 0.00000 -- _onnx_check_model.py:338:set_opset_imports (set_opset_imports) +++
            __init__                                         --    1    1 -- 0.00000 0.00000 -- _onnx_check_model.py:377:__init__ (__init__) +++
            _check_graph                                     --    1    1 -- 0.00149 0.04907 -- _onnx_check_model.py:962:_check_graph (_check_graph)
                get_ir_version                               --   20   20 -- 0.00004 0.00004 -- _onnx_check_model.py:326:get_ir_version (get_ir_version) +++
                __init__                                     --    1    1 -- 0.00000 0.00001 -- _onnx_check_model.py:377:__init__ (__init__) +++
                add                                          --   41   41 -- 0.00014 0.00018 -- _onnx_check_model.py:384:add (add)
                    <method 'add' of 'set' objects>          --   41   41 -- 0.00004 0.00004 -- ~:0:<method 'add' of 'set' objects> (<method 'add' of 'set' objects>) +++
                this_graph_has                               --    1    1 -- 0.00000 0.00000 -- _onnx_check_model.py:388:this_graph_has (this_graph_has) +++
                this_or_ancestor_graph_has                   --   60   60 -- 0.00011 0.00021 -- _onnx_check_model.py:392:this_or_ancestor_graph_has (this_or_ancestor_graph_has) +++
                _enforce_has_field                           --   20   20 -- 0.00004 0.00016 -- _onnx_check_model.py:405:_enforce_has_field (_enforce_has_field) +++
                _enforce_non_empty_field                     --    1    1 -- 0.00001 0.00001 -- _onnx_check_model.py:417:_enforce_non_empty_field (_enforce_non_empty_field) +++
                _check_value_info                            --    2    2 -- 0.00016 0.00049 -- _onnx_check_model.py:423:_check_value_info (_check_value_info)
                    is_main_graph                            --    2    2 -- 0.00000 0.00000 -- _onnx_check_model.py:342:is_main_graph (is_main_graph)
                    _enforce_has_field                       --    6    6 -- 0.00001 0.00003 -- _onnx_check_model.py:405:_enforce_has_field (_enforce_has_field) +++
                    _enforce_non_empty_field                 --    2    2 -- 0.00000 0.00001 -- _onnx_check_model.py:417:_enforce_non_empty_field (_enforce_non_empty_field) +++
                    <method 'ByteSize' ...CMessage' objects> --   12   12 -- 0.00003 0.00003 -- ~:0:<method 'ByteSize' of 'google.protobuf.pyext._message.CMessage' objects> (<method 'ByteSize' of 'google.protobuf.pyext._message.CMessage' objects>)
                    <method 'endswith' of 'str' objects>     --  138  138 -- 0.00006 0.00006 -- ~:0:<method 'endswith' of 'str' objects> (<method 'endswith' of 'str' objects>) +++
                    <built-in method builtins.dir>           --    2    2 -- 0.00014 0.00014 -- ~:0:<built-in method builtins.dir> (<built-in method builtins.dir>)
                    <built-in method builtins.getattr>       --   12   12 -- 0.00005 0.00005 -- ~:0:<built-in method builtins.getattr> (<built-in method builtins.getattr>) +++
                _check_tensor                                --   20   20 -- 0.01124 0.04066 -- _onnx_check_model.py:484:_check_tensor (_check_tensor)
                    _enforce_has_field                       --   20   20 -- 0.00002 0.00005 -- _onnx_check_model.py:405:_enforce_has_field (_enforce_has_field) +++
                    _check_data_field                        --  100  100 -- 0.00043 0.01959 -- _onnx_check_model.py:465:_check_data_field (_check_data_field)
                        <built-in method builtins.getattr>   --  120  120 -- 0.01901 0.01901 -- ~:0:<built-in method builtins.getattr> (<built-in method builtins.getattr>) +++
                        <built-in method builtins.len>       --  100  100 -- 0.00015 0.00015 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>) +++
                    <built-in method builtins.hasattr>       --   40   40 -- 0.00971 0.00971 -- ~:0:<built-in method builtins.hasattr> (<built-in method builtins.hasattr>) +++
                    <built-in method builtins.len>           --   20   20 -- 0.00008 0.00008 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>) +++
                _check_node                                  --   20   20 -- 0.00061 0.00579 -- _onnx_check_model.py:911:_check_node (_check_node)
                    __getattr__                              --   20   20 -- 0.00013 0.00031 -- _onnx_check_model.py:68:__getattr__ (__getattr__) +++
                    verify                                   --   20   20 -- 0.00141 0.00415 -- _onnx_check_model.py:83:verify (verify)
                        __init__                             --   20   20 -- 0.00006 0.00006 -- _onnx_check_model.py:31:__init__ (__init__)
                        __getattr__                          --  140  140 -- 0.00086 0.00248 -- _onnx_check_model.py:68:__getattr__ (__getattr__) +++
                        num_inputs_allowed                   --   20   20 -- 0.00001 0.00001 -- _onnx_check_model.py:73:num_inputs_allowed (num_inputs_allowed)
                        num_outputs_allowed                  --   20   20 -- 0.00001 0.00001 -- _onnx_check_model.py:78:num_outputs_allowed (num_outputs_allowed)
                        <built-in method builtins.len>       --  200  200 -- 0.00017 0.00017 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>) +++
                    GetSchema                                --   20   20 -- 0.00003 0.00055 -- _onnx_check_model.py:302:GetSchema (GetSchema)
                        get_schema                           --   20   20 -- 0.00008 0.00052 -- _onnx_check_model.py:295:get_schema (get_schema)
                            __init__                         --   20   20 -- 0.00003 0.00003 -- _onnx_check_model.py:65:__init__ (__init__)
                            <built-in metho...fs.get_schema> --   20   20 -- 0.00041 0.00041 -- ~:0:<built-in method onnx.onnx_cpp2py_export.defs.get_schema> (<built-in method onnx.onnx_cpp2py_export.defs.get_schema>)
                    get_opset_imports                        --   20   20 -- 0.00001 0.00001 -- _onnx_check_model.py:334:get_opset_imports (get_opset_imports) +++
                    get_schema_registry                      --   20   20 -- 0.00001 0.00001 -- _onnx_check_model.py:354:get_schema_registry (get_schema_registry)
                    _enforce_non_empty_field                 --   20   20 -- 0.00002 0.00007 -- _onnx_check_model.py:417:_enforce_non_empty_field (_enforce_non_empty_field) +++
                    check_is_experimental_op                 --   20   20 -- 0.00007 0.00007 -- _onnx_check_model.py:1296:check_is_experimental_op (check_is_experimental_op)
                <method 'append' of 'list' objects>          --   20   20 -- 0.00001 0.00001 -- ~:0:<method 'append' of 'list' objects> (<method 'append' of 'list' objects>)
                <method 'add' of 'set' objects>              --   20   20 -- 0.00002 0.00002 -- ~:0:<method 'add' of 'set' objects> (<method 'add' of 'set' objects>) +++
            _check_model_local_functions                     --    1    1 -- 0.00002 0.00004 -- _onnx_check_model.py:1122:_check_model_local_functions (_check_model_local_functions)
                __init__                                     --    1    1 -- 0.00001 0.00001 -- _onnx_check_model.py:312:__init__ (__init__) +++
                get_opset_imports                            --    1    1 -- 0.00000 0.00000 -- _onnx_check_model.py:334:get_opset_imports (get_opset_imports) +++
                set_opset_imports                            --    1    1 -- 0.00000 0.00000 -- _onnx_check_model.py:338:set_opset_imports (set_opset_imports) +++
<built-in method builtins.len>                               --  321  321 -- 0.00041 0.00041 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>)
<method 'add' of 'set' objects>                              --   61   61 -- 0.00006 0.00006 -- ~:0:<method 'add' of 'set' objects> (<method 'add' of 'set' objects>)
<built-in method builtins.hasattr>                           --  246  246 -- 0.01083 0.01083 -- ~:0:<built-in method builtins.hasattr> (<built-in method builtins.hasattr>)
<method 'endswith' of 'str' objects>                         --  298  298 -- 0.00026 0.00026 -- ~:0:<method 'endswith' of 'str' objects> (<method 'endswith' of 'str' objects>)
<built-in method builtins.getattr>                           --  315  315 -- 0.01978 0.01978 -- ~:0:<built-in method builtins.getattr> (<built-in method builtins.getattr>)

Total running time of the script: ( 2 minutes 9.090 seconds)

Gallery generated by Sphinx-Gallery