# Source code for mlprodict.testing.test_utils.utils_backend
"""
Inspired by :epkg:`sklearn-onnx`; handles two backends, *onnxruntime* and *python*.
:githublink:`%|py|5`
"""
from .utils_backend_onnxruntime import compare_runtime as compare_runtime_ort
from .utils_backend_python import compare_runtime as compare_runtime_pyrt
def compare_backend(backend, test, decimal=5, options=None, verbose=False,
                    context=None, comparable_outputs=None,
                    intermediate_steps=False, classes=None,
                    disable_optimisation=False):
    """
    Compares the expected output (computed with the model before
    it was converted to ONNX) with the output produced by running
    the ONNX model on the selected backend.

    :param backend: backend to use to run the comparison,
        either ``'onnxruntime'`` or ``'python'``
    :param test: dictionary with the following keys:

        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param context: specifies custom operators; it is currently
        accepted but NOT forwarded to either backend
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error; only forwarded to the ``'python'``
        backend, the ``'onnxruntime'`` backend always receives ``False``
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation onnxruntime
        could do
    :return: the value returned by the selected backend's
        ``compare_runtime``, documented as a tuple
        (output, lambda function to call onnx predictions)
    :raises ValueError: if *backend* is not ``'onnxruntime'``
        or ``'python'``

    The backend comparison is expected to raise an error
    if the comparison fails.

    :githublink:`%|py|37`
    """
    # NOTE(review): `context` is not passed to either backend below;
    # confirm whether this is intended.
    if backend == "onnxruntime":
        # The caller's `intermediate_steps` value is not forwarded here:
        # False is always passed. NOTE(review): presumably the onnxruntime
        # comparison does not support intermediate steps -- confirm.
        return compare_runtime_ort(
            test, decimal, options=options, verbose=verbose,
            comparable_outputs=comparable_outputs,
            intermediate_steps=False, classes=classes,
            disable_optimisation=disable_optimisation)
    if backend == "python":
        return compare_runtime_pyrt(
            test, decimal, options=options, verbose=verbose,
            comparable_outputs=comparable_outputs,
            intermediate_steps=intermediate_steps, classes=classes,
            disable_optimisation=disable_optimisation)
    raise ValueError(  # pragma: no cover
        "Does not support backend '{0}'.".format(backend))