Note
Click here to download the full example code
First examples with onnxruntime#
The example "First examples with onnx-array-api" defines a custom
loss and then executes it with the class
onnx.reference.ReferenceEvaluator.
The next example replaces that evaluator with onnxruntime.
Example#
import numpy as np
from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.ort.ort_tensors import JitOrtTensor, OrtTensor
def l1_loss(x, y):
    """Return the L1 loss: the sum of the absolute differences of *x* and *y*."""
    return absolute(x - y).sum()
def l2_loss(x, y):
    """Return the L2 loss: the sum of the squared differences of *x* and *y*."""
    return ((x - y) ** 2).sum()
def myloss(x, y):
    """Return the L1 loss of column 0 plus the L2 loss of column 1.

    Both *x* and *y* are indexed with ``[:, 0]`` and ``[:, 1]``, so they
    must be two-dimensional with at least two columns.
    """
    l1 = l1_loss(x[:, 0], y[:, 0])
    l2 = l2_loss(x[:, 1], y[:, 1])
    return l1 + l2
# Wrap myloss with jit_onnx, using JitOrtTensor as the tensor class and
# requesting opset 17 for the "" domain with IR version 8.
ort_myloss = jit_onnx(myloss, JitOrtTensor, target_opsets={"": 17}, ir_version=8)
# Two small 2x2 float32 inputs.
x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
# Build OrtTensor inputs from the numpy arrays.
xort = OrtTensor.from_array(x)
yort = OrtTensor.from_array(y)
# Call the wrapped loss on the OrtTensor inputs and print the result as numpy.
res = ort_myloss(xort, yort)
print(res.numpy())
0.042
Profiling#
from pyquickhelper.pycode.profiling import profile, profile2graph
# Larger random float32 inputs (10000 rows, 2 columns) used for profiling.
x = np.random.randn(10000, 2).astype(np.float32)
y = np.random.randn(10000, 2).astype(np.float32)
# OrtTensor inputs built from the same arrays.
xort = OrtTensor.from_array(x)
yort = OrtTensor.from_array(y)
def loop_ort(n):
    """Call ``ort_myloss(xort, yort)`` *n* times, discarding the results."""
    for _ in range(n):
        ort_myloss(xort, yort)
def loop_numpy(n):
    """Call ``myloss(x, y)`` on the numpy arrays *n* times, discarding the results."""
    for _ in range(n):
        myloss(x, y)
def loop(n=1000):
    """Run the numpy loop and then the onnxruntime loop, *n* iterations each."""
    loop_numpy(n)
    loop_ort(n)
# Profile a call to loop(); only the first element of profile()'s result is kept.
ps = profile(loop)[0]
# Build the call graph; clean_text keeps only the part after the last "/".
root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1])
# Render the graph as text and print it.
text = root.to_text()
print(text)
var -- 2000 2000 -- 0.01854 0.12970 -- npx_core_api.py:44:var (var)
__init__ -- 2000 2000 -- 0.08160 0.11116 -- npx_var.py:273:__init__ (__init__) +++
info -- 4000 4000 -- 0.04804 0.06999 -- npx_jit_eager.py:52:info (info)
info -- 4000 4000 -- 0.01045 0.01777 -- __init__.py:1424:info (info)
isEnabledFor -- 4000 4000 -- 0.00732 0.00732 -- __init__.py:1677:isEnabledFor (isEnabledFor)
<built-in method builtins.len> -- 4000 4000 -- 0.00419 0.00419 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>) +++
get_cst_var -- 2000 2000 -- 0.03444 0.04732 -- npx_var.py:202:get_cst_var (get_cst_var)
parent -- 2000 2000 -- 0.00892 0.01288 -- <frozen importlib._bootstrap>:398:parent (parent)
<method 'rpartition' of 'str' objects> -- 2000 2000 -- 0.00396 0.00396 -- ~:0:<method 'rpartition' of 'str' objects> (<method 'rpartition' of 'str' objects>)
__init__ -- 3000 3000 -- 0.12511 0.17645 -- npx_var.py:273:__init__ (__init__)
__init__ -- 3000 3000 -- 0.00529 0.00529 -- npx_var.py:267:__init__ (__init__)
<listcomp> -- 3000 3000 -- 0.00512 0.00512 -- npx_var.py:330:<listcomp> (<listcomp>)
self_var -- 2000 2000 -- 0.00332 0.00486 -- npx_var.py:350:self_var (self_var) +++
<method 'ravel' of 'numpy.ndarray' objects> -- 1000 1000 -- 0.00263 0.00263 -- ~:0:<method 'ravel' of 'numpy.ndarray' objects> (<method 'ravel' of 'numpy.ndarray' objects>)
<method 'items' of 'dict' objects> -- 3000 3000 -- 0.00230 0.00230 -- ~:0:<method 'items' of 'dict' objects> (<method 'items' of 'dict' objects>) +++
<built-in method builtins.hasattr> -- 3000 3000 -- 0.00217 0.00217 -- ~:0:<built-in method builtins.hasattr> (<built-in method builtins.hasattr>) +++
<built-in method builtins.isinstance> -- 23000 23000 -- 0.02518 0.02518 -- ~:0:<built-in method builtins.isinstance> (<built-in method builtins.isinstance>) +++
<built-in method builtins.len> -- 6000 6000 -- 0.00379 0.00379 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>) +++
self_var -- 4000 4000 -- 0.00722 0.01062 -- npx_var.py:350:self_var (self_var)
<built-in method builtins.hasattr> -- 4000 4000 -- 0.00340 0.00340 -- ~:0:<built-in method builtins.hasattr> (<built-in method builtins.hasattr>) +++
__init__ -- 3000 3000 -- 0.01239 0.02015 -- ort_tensors.py:126:__init__ (__init__)
<built-in method builtins.isinstance> -- 5000 5000 -- 0.00776 0.00776 -- ~:0:<built-in method builtins.isinstance> (<built-in method builtins.isinstance>) +++
value -- 3000 3000 -- 0.00265 0.00265 -- ort_tensors.py:149:value (value)
loop -- 1 1 -- 0.00001 3.46192 -- plot_onnxruntime.py:69:loop (loop)
loop_ort -- 1 1 -- 0.01621 2.88319 -- plot_onnxruntime.py:59:loop_ort (loop_ort)
__call__ -- 1000 1000 -- 0.01715 2.86698 -- npx_jit_eager.py:419:__call__ (__call__)
info -- 2000 2000 -- 0.02059 0.03070 -- npx_jit_eager.py:52:info (info) +++
cast_to_tensor_class -- 1000 1000 -- 0.01698 0.03308 -- npx_jit_eager.py:266:cast_to_tensor_class (cast_to_tensor_class)
__init__ -- 2000 2000 -- 0.00784 0.01450 -- ort_tensors.py:126:__init__ (__init__) +++
<method 'append' of 'list' objects> -- 2000 2000 -- 0.00161 0.00161 -- ~:0:<method 'append' of 'list' objects> (<method 'append' of 'list' objects>) +++
cast_from_tensor_class -- 1000 1000 -- 0.00718 0.01077 -- npx_jit_eager.py:283:cast_from_tensor_class (cast_from_tensor_class)
value -- 1000 1000 -- 0.00102 0.00102 -- ort_tensors.py:149:value (value) +++
<built-in method builtins.isinstance> -- 1000 1000 -- 0.00194 0.00194 -- ~:0:<built-in method builtins.isinstance> (<built-in method builtins.isinstance>) +++
<built-in method builtins.len> -- 1000 1000 -- 0.00063 0.00063 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>) +++
jit_call -- 1000 1000 -- 0.02660 2.77529 -- npx_jit_eager.py:327:jit_call (jit_call)
info -- 2000 2000 -- 0.02745 0.03929 -- npx_jit_eager.py:52:info (info) +++
make_key -- 1000 1000 -- 0.01808 0.06576 -- npx_jit_eager.py:123:make_key (make_key)
key -- 2000 2000 -- 0.01190 0.04247 -- ort_tensors.py:144:key (key)
shape -- 2000 2000 -- 0.01381 0.01381 -- ort_tensors.py:134:shape (shape)
dtype -- 2000 2000 -- 0.01518 0.01518 -- ort_tensors.py:139:dtype (dtype)
<built-in method builtins.len> -- 2000 2000 -- 0.00159 0.00159 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>) +++
<method 'append' of 'list' objects> -- 2000 2000 -- 0.00127 0.00127 -- ~:0:<method 'append' of 'list' objects> (<method 'append' of 'list' objects>) +++
<built-in method builtins.isinstance> -- 2000 2000 -- 0.00395 0.00395 -- ~:0:<built-in method builtins.isinstance> (<built-in method builtins.isinstance>) +++
move_input_to_kwargs -- 1000 1000 -- 0.00297 0.00372 -- npx_jit_eager.py:298:move_input_to_kwargs (move_input_to_kwargs)
<built-in method builtins.len> -- 1000 1000 -- 0.00076 0.00076 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>) +++
run -- 1000 1000 -- 2.63138 2.63991 -- ort_tensors.py:106:run (run)
__init__ -- 1000 1000 -- 0.00455 0.00566 -- ort_tensors.py:126:__init__ (__init__) +++
value -- 2000 2000 -- 0.00163 0.00163 -- ort_tensors.py:149:value (value) +++
<built-in method builtins.len> -- 2000 2000 -- 0.00125 0.00125 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>) +++
loop_numpy -- 1 1 -- 0.00321 0.57872 -- plot_onnxruntime.py:64:loop_numpy (loop_numpy)
myloss -- 1000 1000 -- 0.02229 0.57552 -- plot_onnxruntime.py:31:myloss (myloss)
__add__ -- 1000 1000 -- 0.00378 0.11703 -- npx_var.py:604:__add__ (__add__)
_binary_op -- 1000 1000 -- 0.01473 0.11325 -- npx_var.py:574:_binary_op (_binary_op)
var -- 1000 1000 -- 0.00891 0.06937 -- npx_core_api.py:44:var (var) +++
get_cst_var -- 1000 1000 -- 0.01738 0.02394 -- npx_var.py:202:get_cst_var (get_cst_var) +++
self_var -- 1000 1000 -- 0.00188 0.00278 -- npx_var.py:350:self_var (self_var) +++
<built-in method builtins.isinstance> -- 1000 1000 -- 0.00243 0.00243 -- ~:0:<built-in method builtins.isinstance> (<built-in method builtins.isinstance>) +++
l1_loss -- 1000 1000 -- 0.06004 0.30350 -- plot_onnxruntime.py:23:l1_loss (l1_loss)
wrapper -- 1000 1000 -- 0.05824 0.14329 -- npx_core_api.py:120:wrapper (wrapper)
annotation -- 1000 1000 -- 0.00089 0.00089 -- inspect.py:2573:annotation (annotation)
kind -- 2000 2000 -- 0.00182 0.00182 -- inspect.py:2577:kind (kind)
parameters -- 2000 2000 -- 0.00224 0.00224 -- inspect.py:2882:parameters (parameters)
return_annotation -- 1000 1000 -- 0.00123 0.00123 -- inspect.py:2886:return_annotation (return_annotation)
__init__ -- 1000 1000 -- 0.04351 0.06529 -- npx_var.py:273:__init__ (__init__) +++
<method 'items' of ...ingproxy' objects> -- 1000 1000 -- 0.00177 0.00177 -- ~:0:<method 'items' of 'mappingproxy' objects> (<method 'items' of 'mappingproxy' objects>)
<method 'append' of 'list' objects> -- 1000 1000 -- 0.00106 0.00106 -- ~:0:<method 'append' of 'list' objects> (<method 'append' of 'list' objects>) +++
<method 'items' of 'dict' objects> -- 1000 1000 -- 0.00068 0.00068 -- ~:0:<method 'items' of 'dict' objects> (<method 'items' of 'dict' objects>) +++
<built-in method builtins.any> -- 1000 1000 -- 0.00259 0.00591 -- ~:0:<built-in method builtins.any> (<built-in method builtins.any>)
<lambda> -- 1000 1000 -- 0.00164 0.00332 -- npx_core_api.py:121:<lambda> (<lambda>)
<built-in metho...ns.isinstance> -- 1000 1000 -- 0.00167 0.00167 -- ~:0:<built-in method builtins.isinstance> (<built-in method builtins.isinstance>) +++
<built-in method builtins.isinstance> -- 2000 2000 -- 0.00202 0.00202 -- ~:0:<built-in method builtins.isinstance> (<built-in method builtins.isinstance>) +++
<built-in method builtins.issubclass> -- 2000 2000 -- 0.00152 0.00152 -- ~:0:<built-in method builtins.issubclass> (<built-in method builtins.issubclass>)
<built-in method builtins.len> -- 1000 1000 -- 0.00062 0.00062 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>) +++
sum -- 1000 1000 -- 0.00426 0.10016 -- npx_var.py:878:sum (sum)
reduce_function -- 1000 1000 -- 0.00922 0.09590 -- npx_var.py:859:reduce_function (reduce_function)
var -- 1000 1000 -- 0.00963 0.06033 -- npx_core_api.py:44:var (var) +++
get_cst_var -- 1000 1000 -- 0.01706 0.02338 -- npx_var.py:202:get_cst_var (get_cst_var) +++
self_var -- 1000 1000 -- 0.00202 0.00298 -- npx_var.py:350:self_var (self_var) +++
l2_loss -- 1000 1000 -- 0.08354 0.13270 -- plot_onnxruntime.py:27:l2_loss (l2_loss)
<method 'sum' of 'numpy.ndarray' objects> -- 1000 1000 -- 0.00528 0.04916 -- ~:0:<method 'sum' of 'numpy.ndarray' objects> (<method 'sum' of 'numpy.ndarray' objects>)
_sum -- 1000 1000 -- 0.00246 0.04389 -- _methods.py:47:_sum (_sum)
<method 'reduce' ....ufunc' objects> -- 1000 1000 -- 0.04143 0.04143 -- ~:0:<method 'reduce' of 'numpy.ufunc' objects> (<method 'reduce' of 'numpy.ufunc' objects>)
<built-in method builtins.len> -- 17000 17000 -- 0.01283 0.01283 -- ~:0:<built-in method builtins.len> (<built-in method builtins.len>)
<method 'append' of 'list' objects> -- 5000 5000 -- 0.00393 0.00393 -- ~:0:<method 'append' of 'list' objects> (<method 'append' of 'list' objects>)
<built-in method builtins.isinstance> -- 35000 35000 -- 0.04495 0.04495 -- ~:0:<built-in method builtins.isinstance> (<built-in method builtins.isinstance>)
<method 'items' of 'dict' objects> -- 4000 4000 -- 0.00297 0.00297 -- ~:0:<method 'items' of 'dict' objects> (<method 'items' of 'dict' objects>)
<built-in method builtins.hasattr> -- 7000 7000 -- 0.00557 0.00557 -- ~:0:<built-in method builtins.hasattr> (<built-in method builtins.hasattr>)
Benchmark#
from pandas import DataFrame
from tqdm import tqdm
from onnx_array_api.ext_test_case import measure_time
# Time both implementations on inputs with a growing number of rows and
# collect one observation per (implementation, size) pair.
data = []
for n in tqdm([1, 10, 100, 1000, 10000, 100000]):
    x = np.random.randn(n, 2).astype(np.float32)
    y = np.random.randn(n, 2).astype(np.float32)

    # Calling myloss directly on the numpy arrays.
    obs = measure_time(lambda: myloss(x, y))
    obs["name"] = "numpy"
    obs["n"] = n
    data.append(obs)

    # Calling the jit_onnx-wrapped loss on OrtTensor inputs.
    xort = OrtTensor.from_array(x)
    yort = OrtTensor.from_array(y)
    obs = measure_time(lambda: ort_myloss(xort, yort))
    obs["name"] = "ort"
    obs["n"] = n
    data.append(obs)

# One row per size, one column per implementation, holding the "average" field.
df = DataFrame(data)
piv = df.pivot(index="n", columns="name", values="average")
piv
0%| | 0/6 [00:00<?, ?it/s]
17%|#6 | 1/6 [00:00<00:02, 2.31it/s]
33%|###3 | 2/6 [00:00<00:01, 2.30it/s]
50%|##### | 3/6 [00:01<00:01, 2.22it/s]
67%|######6 | 4/6 [00:01<00:01, 2.00it/s]
83%|########3 | 5/6 [00:03<00:00, 1.11it/s]
100%|##########| 6/6 [00:15<00:00, 4.73s/it]
100%|##########| 6/6 [00:15<00:00, 2.62s/it]
Plots#
import matplotlib.pyplot as plt
# Two side-by-side axes: the measured values on the left, their ratio on the right.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
# Left: the "numpy" and "ort" columns on log-log axes.
piv.plot(
title="Comparison between numpy and onnxruntime", logx=True, logy=True, ax=ax[0]
)
# Right: ratio of the "ort" column to the "numpy" column, log scale on x only.
piv["ort/numpy"] = piv["ort"] / piv["numpy"]
piv["ort/numpy"].plot(title="Ratio ort/numpy", logx=True, ax=ax[1])
fig.savefig("plot_onnxruntime.png")
findfont: Font family ['STIXGeneral'] not found. Falling back to DejaVu Sans.
findfont: Font family ['STIXGeneral'] not found. Falling back to DejaVu Sans.
findfont: Font family ['STIXGeneral'] not found. Falling back to DejaVu Sans.
findfont: Font family ['STIXNonUnicode'] not found. Falling back to DejaVu Sans.
findfont: Font family ['STIXNonUnicode'] not found. Falling back to DejaVu Sans.
findfont: Font family ['STIXNonUnicode'] not found. Falling back to DejaVu Sans.
findfont: Font family ['STIXSizeOneSym'] not found. Falling back to DejaVu Sans.
findfont: Font family ['STIXSizeTwoSym'] not found. Falling back to DejaVu Sans.
findfont: Font family ['STIXSizeThreeSym'] not found. Falling back to DejaVu Sans.
findfont: Font family ['STIXSizeFourSym'] not found. Falling back to DejaVu Sans.
findfont: Font family ['STIXSizeFiveSym'] not found. Falling back to DejaVu Sans.
findfont: Font family ['cmsy10'] not found. Falling back to DejaVu Sans.
findfont: Font family ['cmr10'] not found. Falling back to DejaVu Sans.
findfont: Font family ['cmtt10'] not found. Falling back to DejaVu Sans.
findfont: Font family ['cmmi10'] not found. Falling back to DejaVu Sans.
findfont: Font family ['cmb10'] not found. Falling back to DejaVu Sans.
findfont: Font family ['cmss10'] not found. Falling back to DejaVu Sans.
findfont: Font family ['cmex10'] not found. Falling back to DejaVu Sans.
findfont: Font family ['DejaVu Sans Display'] not found. Falling back to DejaVu Sans.
Total running time of the script: ( 0 minutes 25.524 seconds)