Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Functions to show shortened options in :epkg:`asv` benchmarks. 

4""" 

5 

6 

def expand_onnx_options(model, optim):
    """
    Expands shortened options. Long names hide some part
    of graphs in :epkg:`asv` benchmark. This trick converts
    a string into real conversion options.

    @param      model       model class (:epkg:`scikit-learn`), an instance
                            of the model is also accepted
    @param      optim       option
    @return                 expanded options

    It is the reverse of function @see fn shorten_onnx_options.
    The following options are handled:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from sklearn.linear_model import LogisticRegression
        from mlprodict.tools.asv_options_helper import expand_onnx_options

        for name in ['cdist', 'nozipmap', 'raw_scores']:
            print(name, ':', expand_onnx_options(LogisticRegression, name))
    """
    # Options must be keyed by the model class. ``model.__class__`` on a
    # class returns its metaclass (usually ``type``), so it is only used
    # when an instance was given.
    model_class = model if isinstance(model, type) else model.__class__
    if optim == 'cdist':
        return {model_class: {'optim': 'cdist'}}
    if optim == 'nozipmap':
        return {model_class: {'zipmap': False}}
    if optim == 'raw_scores':
        return {model_class: {'raw_scores': True, 'zipmap': False}}
    # Any other value (None, already expanded options, ...) is returned as is.
    return optim  # pragma: no cover

39 

40 

def shorten_onnx_options(model, opts):
    """
    Shortens onnx options into a string.
    Long names hide some part of graphs in :epkg:`asv` benchmark.

    @param      model       model class (:epkg:`scikit-learn`)
    @param      opts        options
    @return                 shortened options, None if *opts* is None
                            or does not match any known shortcut

    It is the reverse of function @see fn expand_onnx_options.
    """
    if opts is None:
        return opts
    # Candidates are tested in this order; the first exact match wins.
    shortcuts = (
        ('cdist', {'optim': 'cdist'}),
        ('nozipmap', {'zipmap': False}),
        ('raw_scores', {'raw_scores': True, 'zipmap': False}),
    )
    for short_name, expanded in shortcuts:
        if opts == {model: expanded}:
            return short_name
    return None

62 

63 

def benchmark_version():
    """
    Returns the list of ONNX opset versions to benchmark.
    The following snippet shows which versions are currently used.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.tools.asv_options_helper import benchmark_version
        print(benchmark_version())
    """
    # More opsets can be listed here, e.g. 13, 14, ...
    opsets = [14]
    return opsets

78 

79 

def ir_version():
    """
    Returns the preferred `IR_VERSION
    <https://github.com/onnx/onnx/blob/master/docs/IR.md#onnx-versioning>`_
    as a list; the last element is the one used by default.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.tools.asv_options_helper import ir_version
        print(ir_version())
    """
    preferred = [7]
    return preferred

93 

94 

def get_opset_number_from_onnx(benchmark=True):
    """
    Returns the current :epkg:`onnx` opset
    based on the installed version of :epkg:`onnx`.

    @param      benchmark       if True, returns the latest opset
                                used for benchmarks (last value of
                                @see fn benchmark_version) instead of
                                the one supported by :epkg:`onnx`
    @return                     opset number
    """
    if not benchmark:
        # Imported lazily so that onnx is only needed on this path.
        from onnx.defs import onnx_opset_version  # pylint: disable=W0611
        return onnx_opset_version()
    return benchmark_version()[-1]

108 

109 

def get_ir_version_from_onnx(benchmark=True):
    """
    Returns the current :epkg:`onnx` :epkg:`IR_VERSION`
    based on the installed version of :epkg:`onnx`.

    @param      benchmark       if True, returns the latest IR version
                                tested with this package (last value of
                                @see fn ir_version) instead of the one
                                defined by :epkg:`onnx`
    @return                     IR version number

    .. faqref::
        :title: Failed to load model with error: Unknown model file format version.
        :lid: l-onnx-ir-version-fail

        :epkg:`onnxruntime` (or ``runtime='onnxruntime1'`` with @see cl OnnxInference)
        sometimes fails to load a model and shows the following error message:

        ::

            RuntimeError: Unable to create InferenceSession due to '[ONNXRuntimeError] :
            2 : INVALID_ARGUMENT : Failed to load model with error: Unknown model file format version.'

        This is caused by the metadata ``ir_version``, which defines the
        :epkg:`IR_VERSION` or *ONNX version*. When a machine learned
        model is converted, the conversion usually uses the default version
        (``ir_version``) returned by the :epkg:`onnx` package.
        :epkg:`onnxruntime` raises the above error message
        when this version (``ir_version``) is too recent. In that case,
        :epkg:`onnxruntime` should be updated to the latest available
        version, or the metadata ``ir_version`` can simply be changed to
        a lower number. The function @see fn get_ir_version_from_onnx
        returns the latest version tested with *mlprodict*.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning

            from sklearn.linear_model import LinearRegression
            from sklearn.datasets import load_iris
            from mlprodict.onnxrt import OnnxInference
            import numpy

            iris = load_iris()
            X = iris.data[:, :2]
            y = iris.target
            lr = LinearRegression()
            lr.fit(X, y)

            # Conversion into ONNX.
            from mlprodict.onnx_conv import to_onnx
            model_onnx = to_onnx(lr, X.astype(numpy.float32),
                                 target_opset=12)
            print("ir_version", model_onnx.ir_version)

            # Change ir_version
            model_onnx.ir_version = 6

            # Predictions with onnxruntime
            oinf = OnnxInference(model_onnx, runtime='onnxruntime1')
            ypred = oinf.run({'X': X[:5].astype(numpy.float32)})
            print("ONNX output:", ypred)

            # To avoid hard-coding a version number, you can use
            # the value returned by function get_ir_version_from_onnx
            from mlprodict.tools import get_ir_version_from_onnx
            model_onnx.ir_version = get_ir_version_from_onnx()
            print("ir_version", model_onnx.ir_version)
    """
    if not benchmark:
        # Imported lazily so that onnx is only needed on this path.
        from onnx import IR_VERSION  # pylint: disable=W0611
        return IR_VERSION
    return ir_version()[-1]

181 

182 

def display_onnx(model_onnx, max_length=1000):
    """
    Returns a shortened string of the model.

    @param      model_onnx      onnx model (any object, converted with ``str``)
    @param      max_length      maximal string length, None to disable
                                any truncation
    @return                     string

    When the string is longer than *max_length*, only its first
    ``max_length // 2`` characters and its last
    ``max_length - max_length // 2`` characters are kept,
    separated by a line ``[...]``.
    """
    res = str(model_onnx)
    if max_length is None or len(res) <= max_length:
        return res
    # A negative budget is meaningless, treat it as 0.
    budget = max(max_length, 0)
    head = budget // 2
    tail = budget - head
    begin = res[:head]
    # Slicing from an explicit start avoids ``res[-0:]``, which would
    # return the whole string when the tail length is 0.
    end = res[len(res) - tail:]
    return "\n".join([begin, '[...]', end])

197 

198 

def version2number(vers):
    """
    Converts a version number into a real number.

    Each dot-separated component is weighted 1000 times less than the
    previous one: ``'1.2.3'`` becomes ``1.002003``. Components are assumed
    to be below 1000, otherwise they overlap with the next position.

    @param      vers        version string such as ``'1.10.2'``
    @return                 number

    A component :func:`int` cannot parse contributes its leading digits
    (``'9rc1'`` counts as ``9``) or ``0`` when it has none (``'dev0'``).
    """
    import re
    r = 0
    for i, s in enumerate(vers.split('.')):
        try:
            vi = int(s)
        except ValueError:
            # Keep the numeric prefix of pre-release components so that
            # '1.9rc1' is not ranked below '1.8'.
            match = re.match(r'\d+', s)
            vi = int(match.group(0)) if match else 0
        r += vi * 10 ** (-i * 3)
    return r