Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
@file
@brief Helpers to validate Python code.
4"""
5import pickle
6import pprint
7import numpy
8from numpy.linalg import det as npy_det # pylint: disable=E0611
9from scipy.spatial.distance import cdist # pylint: disable=E0611
10from scipy.special import expit, erf # pylint: disable=E0611
11from scipy.linalg import solve # pylint: disable=E0611
12from ...tools.code_helper import make_callable
def _make_callable(fct, obj, code, gl, debug):
    """
    Same function as @see fn make_callable but it also handles
    functions which accept an undefined number of arguments.

    @param      fct     name of the function to build
    @param      obj     forwarded as is to @see fn make_callable
    @param      code    forwarded as is to @see fn make_callable
    @param      gl      forwarded as is to @see fn make_callable
    @param      debug   forwarded as is to @see fn make_callable
    @return             a callable

    ``"pyrt_Concat"`` is the only name handled here: it is mapped to a
    variadic wrapper around :func:`numpy.concatenate`. Every other name
    is delegated to @see fn make_callable.
    """
    if fct == "pyrt_Concat":
        # Concat receives as many arrays as the node has inputs, hence
        # the variadic signature.
        def pyrt_Concat_(*inputs, axis=0):
            return numpy.concatenate(inputs, axis=axis)

        return pyrt_Concat_

    return make_callable(fct, obj, code, gl, debug)
def _max_abs_discrepancy(expected, got):
    """
    Computes the largest absolute difference between two arrays
    of the same shape, comparing them element by element.

    @param      expected    expected array
    @param      got         array to check
    @return                 maximum absolute difference, ``numpy.inf``
                            if a NaN faces a value which is not NaN

    Two NaN located at the same position are considered as equal.
    """
    def is_nan(value):
        # numpy.float32 is not a subclass of float,
        # numpy.floating covers every numpy float width.
        return (isinstance(value, (float, numpy.floating)) and
                numpy.isnan(value))

    diff = 0
    for a, b in zip(expected.ravel(), got.ravel()):
        if a == b:
            continue
        nan_a, nan_b = is_nan(a), is_nan(b)
        if nan_a and nan_b:
            continue
        if nan_a or nan_b:
            # max(diff, nan) returns diff, the mismatch would be
            # silently ignored: it is reported as an infinite difference.
            return numpy.inf
        diff = max(diff, abs(a - b))
    return diff


def validate_python_inference(oinf, inputs, tolerance=0.):
    """
    Validates the code produced by method :meth:`to_python
    <mlprodict.onnxrt.onnx_inference_exports.OnnxInferenceExport.to_python>`.
    The function compiles and executes the code
    given as an argument and compares the results to
    what *oinf* returns. This function is mostly used for
    unit testing purpose but it is not robust enough
    to handle all cases.

    @param      oinf        @see cl OnnxInference
    @param      inputs      inputs as dictionary, it is not modified
    @param      tolerance   discrepancies must be below or equal to
                            this threshold

    The function fails if the expected outputs are not the same.
    It raises *TypeError* if the results are not dictionaries,
    *RuntimeError* if the generated code cannot be executed or
    if the result names differ, *ValueError* if shapes or values
    differ, *NotImplementedError* if an expected value is not
    a numpy array.
    """
    from ..ops_cpu.op_argmax import _argmax
    from ..ops_cpu.op_argmin import _argmin
    from ..ops_cpu.op_celu import _vcelu1

    cd = oinf.to_python()
    code = cd['onnx_pyrt_main.py']

    exp = oinf.run(inputs)
    if not isinstance(exp, dict):
        raise TypeError(  # pragma: no cover
            "exp is not a dictionary by '{}'.".format(type(exp)))
    if not exp:
        raise ValueError(  # pragma: no cover
            "No result to compare.")

    # Appends the call to the generated class, every input is passed
    # as a named argument bound to a variable of the same name.
    inps = ['{0}={0}'.format(k) for k in sorted(inputs)]
    code += "\n".join(['', '', 'opi = OnnxPythonInference()',
                       'res = opi.run(%s)' % ', '.join(inps)])

    cp = compile(code, "<string>", mode='exec')
    pyrt_fcts = [_ for _ in cp.co_names if _.startswith("pyrt_")]
    fcts_local = {}

    gl = {'numpy': numpy, 'pickle': pickle,
          'expit': expit, 'erf': erf, 'cdist': cdist,
          '_argmax': _argmax, '_argmin': _argmin,
          '_vcelu1': _vcelu1, 'solve': solve,
          'fft': numpy.fft.fft, 'rfft': numpy.fft.rfft,
          'fft2': numpy.fft.fft2,
          'npy_det': npy_det, 'ndarray': numpy.ndarray}

    # Looks among the constants of the compiled module for the objects
    # whose representation mentions both '<string>' and the function name.
    for fct in pyrt_fcts:
        for obj in cp.co_consts:
            if isinstance(obj, str):
                continue
            sobj = str(obj)
            if '<string>' in sobj and fct in sobj:
                fcts_local[fct] = _make_callable(fct, obj, code, gl, False)

    gl.update(fcts_local)
    # exec stores every name defined by the code ('opi', 'res', ...) into
    # the local namespace: a copy keeps the caller's dictionary untouched.
    loc = dict(inputs)
    try:
        exec(cp, gl, loc)  # pylint: disable=W0122
    except (NameError, TypeError, SyntaxError, IndexError) as exc:  # pragma: no cover
        raise RuntimeError(
            "Unable to execute code\n-----\n{}".format(code)) from exc

    got = loc['res']
    keys = sorted(exp)
    if isinstance(got, numpy.ndarray) and len(keys) == 1:
        # A single array is mapped to the only expected output name.
        got = {keys[0]: got}

    if not isinstance(got, dict):
        raise TypeError(  # pragma: no cover
            "got is not a dictionary by '{}'\n--\n{}\n---\n{}.".format(
                type(got), dir(got), pprint.pformat(str(loc))))
    if len(got) != len(exp):
        raise RuntimeError(  # pragma: no cover
            "Different number of results.\nexp: {}\ngot: {}".format(
                ", ".join(sorted(exp)), ", ".join(sorted(got))))

    if keys != sorted(got):
        raise RuntimeError(  # pragma: no cover
            "Different result names.\nexp: {}\ngot: {}".format(
                ", ".join(sorted(exp)), ", ".join(sorted(got))))

    for k in keys:
        e = exp[k]
        g = got[k]
        if not isinstance(e, numpy.ndarray):
            raise NotImplementedError(  # pragma: no cover
                "Unable to compare values of type '{}'.".format(type(e)))
        if e.shape != g.shape:
            raise ValueError(  # pragma: no cover
                "Shapes are different {} != {}\n---\n{}\n{}.".format(
                    e.shape, g.shape, e, g))
        diff = _max_abs_discrepancy(e, g)
        if diff > tolerance:
            raise ValueError(  # pragma: no cover
                "Values are different (max diff={}>{})\n--EXP--\n{}\n--GOT--"
                "\n{}\n--\n{}".format(diff, tolerance, e, g, code))