# Source code for mlprodict.testing.test_utils.utils_backend_python
"""
Inspired by skl2onnx, handles two backends.
:githublink:`%|py|5`
"""
from ...onnxrt import OnnxInference
from .utils_backend_common_compare import compare_runtime_session
class MockVariableName:
    """
    A string.

    Minimal description of a variable which only knows its name.
    Reading :attr:`shape` or :attr:`type` raises ``NotImplementedError``
    because this class stores neither of them.
    """

    def __init__(self, name):
        """
        :param name: variable name, stored as ``self.name``
        """
        self.name = name

    @property
    def shape(self):
        "returns shape"
        # No shape is stored at this level: always fails.
        raise NotImplementedError(  # pragma: no cover
            "No shape for '{}'.".format(self.name))

    @property
    def type(self):
        "returns type"
        # No type is stored at this level: always fails.
        raise NotImplementedError(  # pragma: no cover
            "No type for '{}'.".format(self.name))
class MockVariableNameShape(MockVariableName):
    """
    A string and a shape.

    Extends :class:`MockVariableName` with a stored shape, so that
    :attr:`shape` returns a value instead of raising.
    """

    def __init__(self, name, sh):
        """
        :param name: variable name, forwarded to the base constructor
        :param sh: shape, stored as is and returned by :attr:`shape`
        """
        super().__init__(name)
        self._shape = sh

    @property
    def shape(self):
        "returns shape"
        return self._shape
class MockVariableNameShapeType(MockVariableNameShape):
    """
    A string and a shape and a type.

    Extends :class:`MockVariableNameShape` with a stored type, so that
    :attr:`type` returns a value instead of raising.
    """

    def __init__(self, name, sh, stype):
        """
        :param name: variable name, forwarded to the base constructor
        :param sh: shape, forwarded to the base constructor
        :param stype: type, stored as is and returned by :attr:`type`
        """
        super().__init__(name, sh)
        self._stype = stype

    @property
    def type(self):
        "returns type"
        return self._stype
class OnnxInference2(OnnxInference):
    """
    onnxruntime API

    Wraps :class:`OnnxInference` so that it can be called with the
    ``run(name, inputs)`` / ``get_outputs()`` calling convention.
    """

    def run(self, name, inputs, *args, **kwargs):  # pylint: disable=W0221
        """
        onnxruntime API

        :param name: name of the output to return, or None to return
            every output
        :param inputs: inputs, forwarded unchanged to ``OnnxInference.run``
        :param args: accepted for signature compatibility, ignored
        :param kwargs: forwarded unchanged to ``OnnxInference.run``
        :return: if *name* is None, the list of outputs ordered as
            ``self.output_names``, otherwise the single output *name*
        :raises RuntimeError: if *name* is not among the computed results
        """
        # NOTE(review): res is indexed by output name below, presumably
        # a dictionary -- confirm against OnnxInference.run.
        res = OnnxInference.run(self, inputs, **kwargs)
        if name is None:
            return [res[n] for n in self.output_names]
        if name in res:  # pragma: no cover
            return res[name]
        raise RuntimeError(  # pragma: no cover
            "Unable to find output '{}'.".format(name))

    def get_outputs(self):
        """
        onnxruntime API

        :return: list of :class:`MockVariableNameShape`, one per item of
            ``self.output_names_shapes``; each item is unpacked into the
            constructor (presumably a pair *(name, shape)* -- to confirm)
        """
        return [MockVariableNameShape(*n) for n in self.output_names_shapes]

    def run_in_scan(self, inputs, verbose=0, fLOG=None):
        """
        Instance to run in operator scan.

        :param inputs: inputs, forwarded to ``OnnxInference.run``
        :param verbose: forwarded to ``OnnxInference.run``
        :param fLOG: forwarded to ``OnnxInference.run``
        :return: whatever ``OnnxInference.run`` returns
        """
        # Calls the base implementation directly: the overridden run of
        # this class expects an output name as first argument.
        return OnnxInference.run(self, inputs, verbose=verbose, fLOG=fLOG)
def compare_runtime(test, decimal=5, options=None,
                    verbose=False, context=None, comparable_outputs=None,
                    intermediate_steps=False, classes=None,
                    disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced with module :epkg:`onnxruntime` or :epkg:`mlprodict`.

    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options
    :param context: specifies custom operators
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation the runtime may do
    :return: tuple (output, lambda function to run the predictions),
        the value is the one returned by *compare_runtime_session*

    The function raises an error if the comparison failed.
    It delegates all the work to *compare_runtime_session* called
    with class :class:`OnnxInference2` as the runtime and every
    argument forwarded unchanged.
    :githublink:`%|py|107`
    """
    return compare_runtime_session(
        OnnxInference2, test, decimal=decimal, options=options,
        verbose=verbose, context=context,
        comparable_outputs=comparable_outputs,
        intermediate_steps=intermediate_steps,
        classes=classes, disable_optimisation=disable_optimisation)