"""
Measures time processing for ONNX models.
:githublink:`%|py|5`
"""
import pickle
import os
from onnxruntime.capi.onnxruntime_pybind11_state import Fail as OrtFail # pylint: disable=E0611
import sklearn
from .. import OnnxInference
from .validate_helper import default_time_kwargs, measure_time, _multiply_time_kwargs
from .validate_benchmark import make_n_rows
class SimplifiedOnnxInference:
    """
    Minimal wrapper around :epkg:`onnxruntime`'s ``InferenceSession``.

    It exposes the two members of the :class:`OnnxInference` API that
    :func:`enumerate_benchmark_replay` relies on: ``input_names`` and ``run``.
    """

    def __init__(self, ort):
        """
        :param ort: model handed unchanged to ``InferenceSession``
            (presumably serialized ONNX bytes or a file name — whatever
            ``InferenceSession`` accepts)
        """
        from onnxruntime import InferenceSession as _InferenceSession
        session = _InferenceSession(ort)
        self.sess = session

    @property
    def input_names(self):
        "Input names of the wrapped session, in ``get_inputs()`` order."
        declared_inputs = self.sess.get_inputs()
        return [node.name for node in declared_inputs]

    def run(self, input):
        """
        Runs the wrapped session.

        :param input: feeds, a dictionary ``{input name: value}``
        :return: the result of ``InferenceSession.run`` with
            ``output_names=None``
        """
        requested_outputs = None  # None = no output name filtering
        return self.sess.run(requested_outputs, input)
def _list_pickled_files(folder):
    """
    Returns the names of the files in *folder* ending with ``.pkl`` or
    ``.pickle``, in the order :func:`os.listdir` returns them.

    :param folder: folder to scan (not recursive)
    :return: list of file names (not full paths)
    :raises FileNotFoundError: if *folder* contains no such file
    """
    files = [name for name in os.listdir(folder)
             if name.endswith(('.pkl', '.pickle'))]
    if not files:
        raise FileNotFoundError(
            "Unable to find any file in folder '{}'.".format(folder))
    return files


def _create_runtimes(onx, runtimes, row):
    """
    Creates one inference object per runtime for the ONNX model *onx*.

    ``'onnxruntime'`` is served by :class:`SimplifiedOnnxInference`, any
    other name is given to :class:`OnnxInference` as its *runtime*.
    If creation raises onnxruntime's ``Fail`` or a ``RuntimeError``, the
    message is stored in ``row['ERROR']`` (the last failure wins) and the
    runtime is mapped to None.

    :param onx: model content, the value stored under ``'onnx_bytes'``
    :param runtimes: list of runtime names
    :param row: result row, updated in place on failure
    :return: dictionary ``{runtime name: inference object or None}``
    """
    oinfs = {}
    for rt in runtimes:
        try:
            if rt == 'onnxruntime':
                oinfs[rt] = SimplifiedOnnxInference(onx)
            else:
                oinfs[rt] = OnnxInference(onx, runtime=rt)
        except (OrtFail, RuntimeError) as e:  # pragma: no cover
            row['ERROR'] = str(e)
            oinfs[rt] = None
    return oinfs


def enumerate_benchmark_replay(folder, runtime='python', time_kwargs=None,
                               skip_long_test=True, time_kwargs_fact=None,
                               time_limit=4, verbose=1, fLOG=None):
    """
    Replays a benchmark stored with function
    :func:`enumerate_validated_operator_opsets
    <mlprodict.onnxrt.validate.validate.enumerate_validated_operator_opsets>`
    or command line :ref:`validate_runtime <l-cmd-validate_runtime>`.
    Enumerates the results.

    The pickled files are loaded with :mod:`pickle`, which can execute
    arbitrary code: only replay folders produced by a trusted run.

    :param folder: folder where to find pickled files, all files must have
        *pkl* or *pickle* extension, files whose name contains ``ERROR``
        are skipped
    :param runtime: runtime or runtimes, a list or a comma-separated
        string; ``'onnxruntime'`` runs the model with :epkg:`onnxruntime`
        directly, any other value is given to :class:`OnnxInference`
    :param time_kwargs: to define a more precise way to measure a model,
        :func:`default_time_kwargs` is used when None or ``''``; after
        :func:`_multiply_time_kwargs`, it is iterated as
        ``{n_rows: {'number': ..., 'repeat': ...}}``
    :param skip_long_test: skips tests for high values of N if they seem
        too long (NOTE(review): accepted but not read by this code)
    :param time_kwargs_fact: see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`
    :param time_limit: to skip the rest of the test after this limit
        (in second) (NOTE(review): accepted but not read by this code)
    :param verbose: if >= 1, uses :epkg:`tqdm` when installed; when *fLOG*
        is given, >= 2 logs every file, >= 3 every number of rows,
        >= 4 every measure
    :param fLOG: logging function
    :return: iterator on results, one dictionary per replayed file with
        keys ``folder``, ``filename``, the keys copied from
        ``obj['obs_op']``, ``'<n>-skl'`` / ``'<n>-skl-details'`` for the
        scikit-learn model, ``'<n>-<rt>'`` / ``'<n>-<rt>-detail'`` for
        every runtime, and ``ERROR`` if a runtime could not be created
    :raises FileNotFoundError: if *folder* contains no pickled file
    :raises NotImplementedError: if the ONNX model does not have exactly
        one input
    """
    files = _list_pickled_files(folder)
    if time_kwargs in (None, ''):
        time_kwargs = default_time_kwargs()
    if isinstance(runtime, str):
        runtime = runtime.split(",")

    loop = files
    if verbose >= 1:
        try:
            from tqdm import tqdm
            loop = tqdm(files)
        except ImportError:  # pragma: no cover
            pass

    for pkl in loop:
        if "ERROR" in pkl:
            # File names containing "ERROR" are not replayed.
            if verbose >= 2 and fLOG is not None:  # pragma: no cover
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] skip '{}'.".format(pkl))
            continue  # pragma: no cover
        if verbose >= 2 and fLOG is not None:
            fLOG("[enumerate_benchmark_replay] process '{}'.".format(pkl))

        # Security: pickle.load runs arbitrary code from the file.
        with open(os.path.join(folder, pkl), 'rb') as f:
            obj = pickle.load(f)
        X_test = obj['X_test']
        ort_test = obj['Xort_test']
        onx = obj['onnx_bytes']
        model = obj['skl_model']
        tkw = _multiply_time_kwargs(time_kwargs, time_kwargs_fact, model)

        # 'n_features' is overwritten right below by obs_op's value; the
        # early assignment only fixes its position in the row.
        row = {'folder': folder, 'filename': pkl,
               'n_features': X_test.shape[1]}
        for key in ['assume_finite', 'conv_options',
                    'init_types', 'idtype', 'method_name', 'n_features',
                    'name', 'optim', 'opset', 'predict_kwargs',
                    'output_index', 'problem', 'scenario']:
            row[key] = obj['obs_op'][key]

        oinfs = _create_runtimes(onx, runtime, row)
        # Same method for every number of rows: resolve it once.
        meth = getattr(model, row['method_name'])

        for k, v in sorted(tkw.items()):
            if verbose >= 3 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] process n_rows={} - {}".format(k, v))
            xt = make_n_rows(X_test, k)
            number = v['number']
            repeat = v['repeat']
            with sklearn.config_context(assume_finite=row['assume_finite']):
                skl = measure_time(meth, xt,
                                   number=number, repeat=repeat,
                                   div_by_number=True)
            if verbose >= 4 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] skl={}".format(skl))
            row['%d-skl-details' % k] = skl
            row['%d-skl' % k] = skl['average']

            xto = make_n_rows(ort_test, k)
            for rt in runtime:
                oinf = oinfs[rt]
                if oinf is None:
                    continue  # pragma: no cover
                input_names = oinf.input_names
                if len(input_names) != 1:
                    raise NotImplementedError(  # pragma: no cover
                        "This function only allows one input not {}".format(
                            len(input_names)))
                name = input_names[0]
                # The lambda is consumed immediately by measure_time, so
                # the late binding of oinf/name is safe here.
                ort = measure_time(lambda x: oinf.run({name: x}), xto,
                                   number=number, repeat=repeat,
                                   div_by_number=True)
                if verbose >= 4 and fLOG is not None:
                    fLOG(  # pragma: no cover
                        "[enumerate_benchmark_replay] {}={}".format(rt, ort))
                row['%d-%s-detail' % (k, rt)] = ort
                row['%d-%s' % (k, rt)] = ort['average']
        yield row