Source code for mlprodict.testing.test_utils.utils_backend_common_compare

"""
Inspired by skl2onnx, handles two backends.


:githublink:`%|py|5`
"""
import numpy
import onnx
import pandas
from onnxruntime.capi.onnxruntime_pybind11_state import (  # pylint: disable=E0611
    InvalidArgument as OrtInvalidArgument)
from .utils_backend_common import (
    load_data_and_model, extract_options,
    ExpectedAssertionError, OnnxBackendAssertionError,
    OnnxRuntimeMissingNewOnnxOperatorException,
    _compare_expected, _create_column)


def compare_runtime_session(  # pylint: disable=R0912
        cls_session, test, decimal=5, options=None,
        verbose=False, context=None, comparable_outputs=None,
        intermediate_steps=False, classes=None,
        disable_optimisation=False):
    """
    Compares the expected output (computed with the model before
    being converted to ONNX) and the ONNX output produced with
    module :epkg:`onnxruntime` or :epkg:`mlprodict`.

    :param cls_session: inference session class (like
        :class:`OnnxInference <mlprodict.onnxrt.onnx_inference.OnnxInference>`)
    :param test: dictionary with the following keys:

        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)

    :param decimal: precision of the comparison
    :param options: comparison options
    :param context: specifies custom operators
    :param verbose: in case of error, the function may print more
        information on the standard output
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable any optimisation the runtime may apply
    :return: tuple (output, lambda function to run the predictions)

    The function raises an exception if the comparison fails.

    :githublink:`%|py|46`
    """
    lambda_onnx = None
    if context is None:
        context = {}
    load = load_data_and_model(test, **context)
    if verbose:  # pragma no cover
        print("[compare_runtime] test '{}' loaded".format(test['onnx']))

    onx = test['onnx']

    if options is None:
        options = extract_options(onx) if isinstance(onx, str) else {}
    elif not isinstance(options, dict):
        raise TypeError(  # pragma no cover
            "options must be a dictionary.")

    if verbose:  # pragma no cover
        print("[compare_runtime] InferenceSession('{}')".format(onx))

    # create the inference session
    runtime_options = dict(disable_optimisation=disable_optimisation)
    try:
        sess = cls_session(onx, runtime_options=runtime_options)
    except TypeError as e:  # pragma: no cover
        raise TypeError(  # pylint: disable=W0707
            "Wrong signature for '{}' ({}).".format(cls_session.__name__, e))
    except ExpectedAssertionError as expe:  # pragma no cover
        raise expe
    except Exception as e:  # pylint: disable=W0703
        if "CannotLoad" in options:  # pragma no cover
            raise ExpectedAssertionError(  # pylint: disable=W0707
                "Unable to load onnx '{0}' due to\n{1}".format(onx, e))
        else:  # pragma no cover
            if verbose:  # pragma no cover
                model = onnx.load(onx)
                smodel = "\nJSON ONNX\n" + str(model)
            else:
                smodel = ""
            if ("NOT_IMPLEMENTED : Could not find an implementation "
                    "for the node" in str(e)):
                # onnxruntime does not implement a specific node yet.
                raise OnnxRuntimeMissingNewOnnxOperatorException(  # pylint: disable=W0707
                    "{3} does not implement a new operator "
                    "'{0}'\n{1}\nONNX\n{2}".format(
                        onx, e, smodel, cls_session))
            if "NOT_IMPLEMENTED : Failed to find kernel" in str(e):
                # onnxruntime does not implement a specific node yet
                # in the kernel included in onnxruntime.
                raise OnnxBackendAssertionError(  # pylint: disable=W0707
                    "{3} misses a kernel for operator "
                    "'{0}'\n{1}\nONNX\n{2}".format(
                        onx, e, smodel, cls_session))
            raise OnnxBackendAssertionError(  # pylint: disable=W0707
                "Unable to load onnx '{0}'\nONNX\n{1}\n{2}".format(
                    onx, smodel, e))

    # build the dictionary of inputs expected by the runtime
    input = load["data"]
    DF = options.pop('DF', False)
    if DF:
        inputs = {c: input[c].values for c in input.columns}
        for k in inputs:
            if inputs[k].dtype == numpy.float64:
                inputs[k] = inputs[k].astype(numpy.float32)
            inputs[k] = inputs[k].reshape((inputs[k].shape[0], 1))
    else:
        if isinstance(input, dict):
            inputs = input
        elif isinstance(input, (list, numpy.ndarray, pandas.DataFrame)):
            inp = sess.get_inputs()
            outs = sess.get_outputs()
            if len(outs) == 0:
                raise OnnxBackendAssertionError(  # pragma: no cover
                    "Wrong number of outputs, onnx='{0}'".format(onx))
            if len(inp) == len(input):
                inputs = {i.name: v for i, v in zip(inp, input)}
            elif len(inp) == 1:
                inputs = {inp[0].name: input}
            elif isinstance(input, numpy.ndarray):
                shape = sum(i.shape[1] if len(i.shape) == 2 else i.shape[0]
                            for i in inp)
                if shape == input.shape[1]:
                    inputs = {n.name: input[:, i] for i, n in enumerate(inp)}
                else:
                    raise OnnxBackendAssertionError(  # pragma: no cover
                        "Wrong number of inputs onnx {0} != "
                        "original shape {1}, onnx='{2}'"
                        .format(len(inp), input.shape, onx))
            elif isinstance(input, list):
                try:
                    array_input = numpy.array(input)
                except Exception:  # pragma no cover
                    raise OnnxBackendAssertionError(  # pylint: disable=W0707
                        "Wrong number of inputs onnx {0} != "
                        "original {1}, onnx='{2}'"
                        .format(len(inp), len(input), onx))
                shape = sum(i.shape[1] for i in inp)
                if shape == array_input.shape[1]:
                    inputs = {}
                    c = 0
                    for i, n in enumerate(inp):
                        d = c + n.shape[1]
                        inputs[n.name] = _create_column(
                            [row[c:d] for row in input], n.type)
                        c = d
                else:
                    raise OnnxBackendAssertionError(  # pragma no cover
                        "Wrong number of inputs onnx {0} != "
                        "original shape {1}, onnx='{2}'*"
                        .format(len(inp), array_input.shape, onx))
            elif isinstance(input, pandas.DataFrame):
                try:
                    array_input = numpy.array(input)
                except Exception:  # pragma no cover
                    raise OnnxBackendAssertionError(  # pylint: disable=W0707
                        "Wrong number of inputs onnx {0} != "
                        "original {1}, onnx='{2}'"
                        .format(len(inp), len(input), onx))
                shape = sum(i.shape[1] for i in inp)
                if shape == array_input.shape[1]:
                    inputs = {}
                    c = 0
                    for i, n in enumerate(inp):
                        d = c + n.shape[1]
                        inputs[n.name] = _create_column(
                            input.iloc[:, c:d], n.type)
                        c = d
                else:
                    raise OnnxBackendAssertionError(  # pragma no cover
                        "Wrong number of inputs onnx {0}={1} columns != "
                        "original shape {2}, onnx='{3}'*"
                        .format(len(inp), shape, array_input.shape, onx))
            else:
                raise OnnxBackendAssertionError(  # pragma no cover
                    "Wrong type of inputs onnx {0}, onnx='{1}'".format(
                        type(input), onx))
        else:
            raise OnnxBackendAssertionError(  # pragma no cover
                "Dict or list is expected, not {0}".format(type(input)))

        for k in inputs:
            if isinstance(inputs[k], list):
                inputs[k] = numpy.array(inputs[k])

    options.pop('SklCol', False)  # unused here but in dump_data_and_model

    if verbose:  # pragma no cover
        print("[compare_runtime] type(inputs)={} len={} names={}".format(
            type(input), len(inputs), list(sorted(inputs))))

    # run the predictions with the ONNX runtime
    if verbose:  # pragma no cover
        if intermediate_steps:
            run_options = {'verbose': 3, 'fLOG': print}
        else:
            run_options = {'verbose': 2, 'fLOG': print}
    else:
        run_options = {}
    try:
        try:
            output = sess.run(None, inputs, **run_options)
        except TypeError:  # pragma no cover
            output = sess.run(None, inputs)
        lambda_onnx = lambda: sess.run(None, inputs)  # noqa
        if verbose:  # pragma no cover
            import pprint
            pprint.pprint(output)
    except ExpectedAssertionError as expe:  # pragma no cover
        raise expe
    except (RuntimeError, OrtInvalidArgument) as e:  # pragma no cover
        if intermediate_steps:
            sess.run(None, inputs, verbose=3, fLOG=print)
        if "-Fail" in onx:
            raise ExpectedAssertionError(  # pylint: disable=W0707
                "{1} cannot compute the prediction for '{0}'".format(
                    onx, cls_session))
        else:
            if verbose:  # pragma no cover
                model = onnx.load(onx)
                smodel = "\nJSON ONNX\n" + str(model)
            else:
                smodel = ""
            import pprint
            raise OnnxBackendAssertionError(  # pylint: disable=W0707
                "{4} cannot compute the predictions"
                " for '{0}' due to {1}{2}\n{3}"
                .format(onx, e, smodel, pprint.pformat(inputs), cls_session))
    except Exception as e:  # pragma no cover
        raise OnnxBackendAssertionError(  # pylint: disable=W0707
            "Unable to run onnx '{0}' due to {1}".format(onx, e))
    if verbose:  # pragma no cover
        print("[compare_runtime] done type={}".format(type(output)))

    # compare the expected outputs with the computed ones
    output0 = output.copy()
    if comparable_outputs:
        cmp_exp = [load["expected"][o] for o in comparable_outputs]
        cmp_out = [output[o] for o in comparable_outputs]
    else:
        cmp_exp = load["expected"]
        cmp_out = output

    try:
        _compare_expected(cmp_exp, cmp_out, sess, onx, decimal=decimal,
                          verbose=verbose, classes=classes, **options)
    except ExpectedAssertionError as expe:  # pragma no cover
        raise expe
    except Exception as e:  # pragma no cover
        if verbose:  # pragma no cover
            model = onnx.load(onx)
            smodel = "\nJSON ONNX\n" + str(model)
        else:
            smodel = ""
        raise OnnxBackendAssertionError(  # pylint: disable=W0707
            "Model '{}' has discrepancies with cls='{}'.\n{}: {}{}".format(
                onx, sess.__class__.__name__, type(e), e, smodel))

    return output0, lambda_onnx
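
A minimal usage sketch follows. It is not taken from the test suite: the tiny Add graph, the _SessionAdapter class and the variable names are illustrative assumptions. The adapter only exposes the onnxruntime-style run(None, inputs) call used above and assumes OnnxInference.run returns a dictionary keyed by the names in OnnxInference.output_names; the real test suite ships its own wrappers and feeds models converted with skl2onnx.

# sketch only, not part of the module above
import numpy
from onnx import TensorProto, helper
from mlprodict.onnxrt import OnnxInference
from mlprodict.testing.test_utils.utils_backend_common_compare import (
    compare_runtime_session)


class _SessionAdapter:
    "Hypothetical adapter exposing the subset of the onnxruntime API used here."

    def __init__(self, onx, runtime_options=None):
        # runtime_options is accepted but ignored in this sketch
        self.oinf = OnnxInference(onx)

    def run(self, _, inputs, **kwargs):
        # OnnxInference.run returns a dict keyed by output names
        res = self.oinf.run(inputs, **kwargs)
        return [res[n] for n in self.oinf.output_names]


# one-node model computing Y = X + X
X_info = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
Y_info = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 2])
graph = helper.make_graph(
    [helper.make_node('Add', ['X', 'X'], ['Y'])], 'tiny', [X_info], [Y_info])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 14)])

X = numpy.random.rand(5, 2).astype(numpy.float32)
test = {
    'onnx': model.SerializeToString(),  # model given as an object, not a filename
    'data': {'X': X},                   # inputs keyed by graph input names
    'expected': X + X,                  # reference result computed with numpy
}
output, run_onnx = compare_runtime_session(_SessionAdapter, test)

The call raises OnnxBackendAssertionError if the runtime output and the expected values disagree beyond the requested precision; otherwise it returns the computed outputs and a lambda that reruns the prediction.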