Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Helpers to validate python code. 

4""" 

5import pickle 

6import pprint 

7import numpy 

8from numpy.linalg import det as npy_det # pylint: disable=E0611 

9from scipy.spatial.distance import cdist # pylint: disable=E0611 

10from scipy.special import expit, erf # pylint: disable=E0611 

11from scipy.linalg import solve # pylint: disable=E0611 

12from ...tools.code_helper import make_callable 

13 

14 

def _make_callable(fct, obj, code, gl, debug):
    """
    Variant of @see fn make_callable which also supports
    functions taking an undefined number of arguments.

    @param      fct     name of the function to build
    @param      obj     forwarded unchanged to *make_callable*
    @param      code    forwarded unchanged to *make_callable*
    @param      gl      forwarded unchanged to *make_callable*
    @param      debug   forwarded unchanged to *make_callable*
    @return             a callable

    Only ``"pyrt_Concat"`` is handled locally: it is mapped to a
    variadic wrapper around :func:`numpy.concatenate`. Every other
    name is delegated to *make_callable*.
    """
    # Everything but the variadic concatenation goes to the generic builder.
    if fct != "pyrt_Concat":
        return make_callable(fct, obj, code, gl, debug)

    def pyrt_Concat_(*inputs, axis=0):
        return numpy.concatenate(inputs, axis=axis)

    return pyrt_Concat_

26 

27 

def _is_nan(value):
    """
    Tells whether *value* is a floating point or complex NaN
    whatever its width (python float, numpy.float32, ...).
    Values of any other type (integers, strings, ...) are never NaN.
    """
    return (isinstance(value, (float, complex, numpy.inexact)) and
            bool(numpy.isnan(value)))


def _max_discrepancy(expected, got):
    """
    Returns the largest absolute difference between the elements
    of two arrays, compared position by position after flattening.
    Two NaN at the same position are considered as equal,
    a NaN facing a regular value yields an infinite discrepancy.
    """
    diff = 0
    for a, b in zip(expected.ravel(), got.ravel()):
        if a == b:
            continue
        nan_a, nan_b = _is_nan(a), _is_nan(b)
        if nan_a and nan_b:
            continue  # pragma: no cover
        if nan_a or nan_b:
            # max(diff, nan) returns diff: without this branch, a NaN
            # facing a regular value would be silently ignored.
            return float('inf')
        diff = max(diff, abs(a - b))
    return diff


def validate_python_inference(oinf, inputs, tolerance=0.):
    """
    Validates the code produced by method :meth:`to_python
    <mlprodict.onnxrt.onnx_inference_exports.OnnxInferenceExport.to_python>`.
    The function compiles and executes the code
    given as an argument and compares the results to
    what *oinf* returns. This function is mostly used for
    unit testing purpose but it is not robust enough
    to handle all cases.

    @param      oinf        @see cl OnnxInference
    @param      inputs      inputs as dictionary, it is not modified
    @param      tolerance   discrepancies must be below or equal to
                            this threshold

    The function fails if the expected output are not the same.
    It raises *TypeError* if one of the results is not a dictionary,
    *RuntimeError* if the generated code cannot be executed or if the
    result names differ, *ValueError* if shapes or values differ and
    *NotImplementedError* if an expected result is not a numpy array.
    """
    from ..ops_cpu.op_argmax import _argmax
    from ..ops_cpu.op_argmin import _argmin
    from ..ops_cpu.op_celu import _vcelu1

    cd = oinf.to_python()
    code = cd['onnx_pyrt_main.py']

    exp = oinf.run(inputs)
    if not isinstance(exp, dict):
        raise TypeError(  # pragma: no cover
            "exp is not a dictionary by '{}'.".format(type(exp)))
    if len(exp) == 0:
        raise ValueError(  # pragma: no cover
            "No result to compare.")

    # Appends the instructions running the generated class on the inputs,
    # every input is passed as a keyword argument.
    inps = ['{0}={0}'.format(k) for k in sorted(inputs)]
    code += "\n".join(['', '', 'opi = OnnxPythonInference()',
                       'res = opi.run(%s)' % ', '.join(inps)])

    cp = compile(code, "<string>", mode='exec')
    pyrt_fcts = [_ for _ in cp.co_names if _.startswith("pyrt_")]
    fcts_local = {}

    gl = {'numpy': numpy, 'pickle': pickle,
          'expit': expit, 'erf': erf, 'cdist': cdist,
          '_argmax': _argmax, '_argmin': _argmin,
          '_vcelu1': _vcelu1, 'solve': solve,
          'fft': numpy.fft.fft, 'rfft': numpy.fft.rfft,
          'fft2': numpy.fft.fft2,
          'npy_det': npy_det, 'ndarray': numpy.ndarray}

    # Functions pyrt_* defined by the generated code are rebuilt from
    # the compiled constants and exposed as globals of the executed code.
    for fct in pyrt_fcts:
        for obj in cp.co_consts:
            if isinstance(obj, str):
                continue
            sobj = str(obj)
            if '<string>' in sobj and fct in sobj:
                fcts_local[fct] = _make_callable(fct, obj, code, gl, False)

    gl.update(fcts_local)
    # The executed code stores its own names (res, opi, ...) into the
    # local namespace: a copy keeps the caller's dictionary untouched.
    loc = dict(inputs)
    try:
        exec(cp, gl, loc)  # pylint: disable=W0122
    except (NameError, TypeError, SyntaxError, IndexError) as e:  # pragma: no cover
        raise RuntimeError(
            "Unable to execute code\n-----\n{}".format(code)) from e

    got = loc['res']
    keys = list(sorted(exp))
    if isinstance(got, numpy.ndarray) and len(keys) == 1:
        # A single array is matched with the unique expected output.
        got = {keys[0]: got}

    if not isinstance(got, dict):
        raise TypeError(  # pragma: no cover
            "got is not a dictionary by '{}'\n--\n{}\n---\n{}.".format(
                type(got), dir(got), pprint.pformat(str(loc))))
    if len(got) != len(exp):
        raise RuntimeError(  # pragma: no cover
            "Different number of results.\nexp: {}\ngot: {}".format(
                ", ".join(sorted(exp)), ", ".join(sorted(got))))

    if keys != list(sorted(got)):
        raise RuntimeError(  # pragma: no cover
            "Different result names.\nexp: {}\ngot: {}".format(
                ", ".join(sorted(exp)), ", ".join(sorted(got))))

    for k in keys:
        e = exp[k]
        g = got[k]
        if isinstance(e, numpy.ndarray):
            if e.shape != g.shape:
                raise ValueError(  # pragma: no cover
                    "Shapes are different {} != {}\n---\n{}\n{}.".format(
                        e.shape, g.shape, e, g))
            diff = _max_discrepancy(e, g)
            if diff > tolerance:
                raise ValueError(  # pragma: no cover
                    "Values are different (max diff={}>{})\n--EXP--\n{}\n--GOT--"
                    "\n{}\n--\n{}".format(diff, tolerance, e, g, code))
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unable to compare values of type '{}'.".format(type(e)))

132 "Unable to compare values of type '{}'.".format(type(e)))