Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Measures time processing for ONNX models. 

4""" 

5import numpy 

6from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version 

7from ... import __version__ as ort_version 

8from .validate_helper import default_time_kwargs, measure_time 

9 

10 

def make_n_rows(x, n, y=None):
    """
    Multiplies or reduces the rows of x to get
    exactly *n* rows.

    @param      x       matrix, rows are taken along the first axis,
                        any number of dimensions is supported
    @param      n       number of rows
    @param      y       target (optional), rows are selected exactly
                        like the rows of *x*
    @return             new matrix or two new matrices if y is not None

    If *n* is lower than or equal to the number of rows of *x*,
    the first *n* rows are returned. Otherwise the rows are repeated
    cyclically: row ``i`` of the result is row ``i % x.shape[0]``
    of *x*. The returned arrays are always copies.

    @raises ValueError if *x* has no row but *n* rows are requested
    """
    nrows = x.shape[0]
    if n <= nrows:
        # Plain slicing, which also works for containers that do not
        # support integer-array indexing along the first axis.
        if y is None:
            return x[:n].copy()
        return x[:n].copy(), y[:n].copy()
    if nrows == 0:
        raise ValueError(
            "Unable to build {} rows from an empty array x.shape={}.".format(
                n, x.shape))
    # Cyclic row indices; integer-array indexing returns a copy and
    # keeps every trailing dimension of x and y whatever their rank.
    indices = numpy.arange(n) % nrows
    if y is None:
        return x[indices]
    return x[indices], y[indices]

57 

58 

def _benchmark_nodes(fct, x, N, repeat, number):
    """
    Measures the time spent in every node of a graph for one input.

    @param      fct         function to benchmark, ``fct(x)[1]`` must be a
                            list of dictionaries, one per node, each with
                            a numerical key ``'time'``
    @param      x           input given to *fct*
    @param      N           number of rows of *x*, stored in every row
    @param      repeat      number of measures
    @param      number      number of calls to *fct* within one measure
    @return                 list of dictionaries, one per node, keys
                            ``time`` (average time per call),
                            ``min_time``, ``max_time`` (fastest and slowest
                            measure divided by *number*), ``N``,
                            ``repeat``, ``number``

    @raises ValueError if *repeat* or *number* is lower than 1
    @raises RuntimeError if the number of nodes changes between two calls
    """
    if repeat < 1 or number < 1:
        raise ValueError(
            "repeat={} and number={} must both be >= 1.".format(
                repeat, number))
    # Warm-up call, not measured.
    fct(x)
    main = None
    for __ in range(repeat):
        # Sums the timings of *number* consecutive calls for every node.
        agg = None
        for _ in range(number):
            ms = fct(x)[1]
            if agg is None:
                agg = ms
                for row in agg:
                    row['N'] = N
            else:
                if len(agg) != len(ms):
                    raise RuntimeError(  # pragma: no cover
                        "Not the same number of nodes {} != {}.".format(len(agg), len(ms)))
                for a, b in zip(agg, ms):
                    a['time'] += b['time']
        if main is None:
            main = agg
            # The first measure must take part in min/max as well,
            # otherwise it is silently ignored.
            for row in main:
                row['max_time'] = row['time']
                row['min_time'] = row['time']
        else:
            if len(agg) != len(main):
                raise RuntimeError(  # pragma: no cover
                    "Not the same number of nodes {} != {}.".format(len(agg), len(main)))
            for a, b in zip(main, agg):
                a['time'] += b['time']
                a['max_time'] = max(a['max_time'], b['time'])
                a['min_time'] = min(a['min_time'], b['time'])
    for row in main:
        row['repeat'] = repeat
        row['number'] = number
        row['time'] /= repeat * number
        # min/max are measured over *number* calls, report them per call
        row['max_time'] /= number
        row['min_time'] /= number
    return main


def benchmark_fct(fct, X, time_limit=4, obs=None, node_time=False,
                  time_kwargs=None, skip_long_test=True):
    """
    Benchmarks a function which takes an array
    as an input and changes the number of rows.

    @param      fct             function to benchmark, signature
                                is `fct(xo)`
    @param      X               array
    @param      time_limit      above this time, measurement is stopped
    @param      obs             all information available in a dictionary
    @param      node_time       measure time execution for each node in the graph
    @param      time_kwargs     to define a more precise way to measure a model
    @param      skip_long_test  skips tests for high values of N if they seem too long
    @return                     dictionary with the results, or a list of
                                dictionaries (one per node and per N) if
                                *node_time* is True

    The function uses *obs* to reduce the number of tries it does.
    :epkg:`sklearn:gaussian_process:GaussianProcessRegressor`
    produces huge *NxN* if predict method is called
    with ``return_cov=True``.
    The default for *time_kwargs* is the following:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.validate.validate_helper import default_time_kwargs
        import pprint
        pprint.pprint(default_time_kwargs())

    See also notebook :ref:`onnxnodetimerst` to see how this function
    can be used to measure time spent in each node.
    """
    if time_kwargs is None:
        time_kwargs = default_time_kwargs()  # pragma: no cover

    def allow(N, obs):
        # Skips large N for problems producing a NxN covariance matrix.
        if obs is None:
            return True  # pragma: no cover
        prob = obs['problem']
        if "-cov" in prob and N > 1000:
            return False  # pragma: no cover
        return True

    # Validates every key before sorting: keys of mixed types would
    # otherwise make sorted fail with an obscure TypeError.
    for N in time_kwargs:
        if not isinstance(N, int):
            raise RuntimeError(  # pragma: no cover
                "time_kwargs ({}) is wrong:\n{}".format(
                    type(time_kwargs), time_kwargs))

    res = {}
    for N in sorted(time_kwargs):
        if not allow(N, obs):
            continue  # pragma: no cover
        x = make_n_rows(X, N)
        number = time_kwargs[N]['number']
        repeat = time_kwargs[N]['repeat']
        if node_time:
            res[N] = _benchmark_nodes(fct, x, N, repeat, number)
            continue
        res[N] = measure_time(fct, x, repeat=repeat,
                              number=number, div_by_number=True)
        if (skip_long_test and res[N] is not None and
                res[N].get('total', time_limit) >= time_limit):
            # too long
            break  # pragma: no cover
    if node_time:
        rows = []
        for v in res.values():
            rows.extend(v)
        return rows
    return res