Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Helpers to compare executions. 

4""" 

5import copy 

6import numpy 

7from .validate_difference import measure_relative_difference 

8 

9 

10def _side_by_side_by_values_inputs(sess, inputs, i): 

11 if isinstance(sess, tuple) and inputs is None: 

12 new_sess, new_inputs = sess 

13 elif isinstance(inputs, list): 

14 new_sess = sess 

15 new_inputs = inputs[i] 

16 else: 

17 new_sess = sess 

18 new_inputs = copy.deepcopy(inputs) 

19 return new_sess, new_inputs 

20 

21 

def side_by_side_by_values(sessions, *args, inputs=None,
                           return_results=False, **kwargs):
    """
    Compares the execution of several sessions.
    It calls method :meth:`OnnxInference.run
    <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`
    with value ``intermediate=True`` and compares the results.

    :param sessions: list of class @see cl OnnxInference
    :param inputs: inputs
    :param args: additional parameters for
        :meth:`OnnxInference.run
        <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`
    :param return_results: if True, returns the results as well.
    :param kwargs: additional parameters for
        :meth:`OnnxInference.run
        <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`
    :return: list of dictionaries, or a tuple
        *(list of dictionaries, results)* if *return_results* is True
    :raises ValueError: if *intermediate* is set to False in *kwargs*,
        if a session using the python runtime has *inplace* enabled,
        or if *sessions* is empty

    The first session is considered as the baseline.
    See notebook :ref:`onnxsbsrst` for an example.
    If *inputs* is None, the function assumes
    *sessions* is a list of *tuple(sessions, inputs)*
    because sometimes inputs must be different.

    The first returned row (``step=-1``, ``metric='nb_results'``)
    tells if all sessions produced the same number of results.
    Every following row describes one intermediate result and one
    metric (``'rel-diff'`` or ``'abs-diff'``): ``value[j]`` and
    ``shape[j]`` hold the result of session *j*, ``order[j]`` its
    execution order when available, ``v[j]`` its difference with
    the baseline and ``cmp`` a summary of the highest difference.

    .. versionchanged:: 0.7
        Parameter *return_results* was added. The function
        returns the execution order when available.
    """
    if not kwargs.get('intermediate', True):
        raise ValueError(  # pragma: no cover
            "kwargs must not set intermediate to False")
    kwargs['intermediate'] = True
    # verbose and fLOG are only read here, they remain in kwargs
    # and are forwarded to method run.
    verbose = kwargs.get('verbose', 0)
    fLOG = kwargs.get('fLOG', None)

    # run
    results = []
    orders = []
    for i, sess in enumerate(sessions):
        if (hasattr(sess, 'runtime') and hasattr(sess, 'inplace') and
                sess.runtime in (None, 'python') and sess.inplace):
            raise ValueError(
                "You must disable the inplace mechanism in order to get "
                "true results. See OnnxInference constructor.")
        new_sess, new_inputs = _side_by_side_by_values_inputs(sess, inputs, i)
        if verbose > 0 and fLOG:
            fLOG(  # pragma: no cover
                '[side_by_side_by_values] run session {}/{}'.format(
                    i + 1, len(sessions)))
        res = new_sess.run(new_inputs, *args, **kwargs)
        results.append(list(res.items()))
        orders.append(new_sess.get_execution_order())

    if not results:
        # Without this check, the failure comes from min() below
        # with a message unrelated to the actual cause.
        raise ValueError("sessions must contain at least one session.")

    # same number of results?
    row = {"metric": "nb_results", 'step': -1}
    for i, res in enumerate(results):
        row["v[%d]" % i] = len(res)
    sizes = [len(res) for res in results]
    row['cmp'] = 'OK' if min(sizes) == max(sizes) else '!='
    rows = [row]

    merged = merge_results(results)

    # highest difference below the threshold gives the label
    thresholds = (
        (1e-5, 'OK'),
        (0.0001, 'e<0.0001'),
        (0.001, 'e<0.001'),
        (0.01, 'e<0.01'),
        (0.1, 'e<0.1'),
    )

    # analysis
    for step, (name, res_row) in enumerate(merged):
        baseline = res_row[0]
        for metric in ('rel-diff', 'abs-diff'):
            row = {'step': step, 'name': name, 'metric': metric}

            vals = []
            for j, r in enumerate(res_row):
                order = orders[j]
                if order is not None:
                    row['order[%d]' % j] = order.get(
                        ('res', name), (numpy.nan, ))[0]
                row['value[%d]' % j] = r
                if hasattr(r, 'shape'):
                    row['shape[%d]' % j] = r.shape

                if j == 0:
                    # the baseline is not compared to itself
                    row['v[%d]' % j] = 0
                elif baseline is not None and r is not None:
                    v = measure_relative_difference(
                        baseline, r, abs_diff=metric == 'abs-diff')
                    row['v[%d]' % j] = v
                    vals.append(v)

            if vals:
                diff = max(vals)
                for threshold, label in thresholds:
                    if diff < threshold:
                        row['cmp'] = label
                        break
                else:
                    row['cmp'] = "ERROR->=%1.1f" % diff

            rows.append(row)

    if return_results:
        return rows, results
    return rows

134 

135 

136def merge_results(results): 

137 """ 

138 Merges results by name. The first ones 

139 are used to keep the order. 

140 

141 :param results: results of intermediate variables 

142 :return: list of tuple 

143 """ 

144 # matrix of names 

145 rows = [(k, []) for k, _ in results[0]] 

146 positions = {k[0]: i for i, k in enumerate(rows)} 

147 todos = [] 

148 for result in results: 

149 todo = [] 

150 for row in rows: 

151 row[1].append(None) 

152 for i, (k, v) in enumerate(result): 

153 pos = positions.get(k, None) 

154 if pos is None: 

155 todo.append((i, k, v)) 

156 else: 

157 rows[pos][1][-1] = (v, i) 

158 todos.append(todo) 

159 

160 # left over 

161 if len(todos) > 0: 

162 for i, todo in enumerate(todos): 

163 if len(todo) == 0: 

164 continue 

165 for pos, name, val in todo: 

166 pos1 = pos + 1 

167 found = -1 

168 for ik, row in enumerate(rows): 

169 if row[1][i] is not None and row[1][i][1] == pos1: 

170 found = ik 

171 break 

172 vv = [None] * len(results) 

173 if found == -1: 

174 vv[i] = (val, len(rows)) 

175 rows.append((name, vv)) 

176 else: 

177 vv[i] = (val, pos) 

178 rows.insert(found, (name, vv)) 

179 

180 # final 

181 final = [] 

182 for row in rows: 

183 nrow = (row[0], [_ if _ is None else _[0] for _ in row[1]]) 

184 final.append(nrow) 

185 return final