Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

@brief Functions translating a python function into code which builds an ONNX graph, plus helpers used by the translated code.

4""" 

5import inspect 

6import ast 

7from textwrap import dedent 

8import numpy 

9from scipy.spatial.distance import squareform, pdist 

10from .node_visitor_translator import CodeNodeVisitor 

11 

12 

def py_make_float_array(cst, op_version=None):
    """
    Wraps a constant into a one-element :epkg:`numpy` array
    of type float32.

    @param      cst         constant to wrap
    @param      op_version  ignored, only kept so that every ``py_*``
                            helper shares the same signature
    @return                 :epkg:`numpy` array ``[cst]`` with dtype float32

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnx_grammar.onnx_translation import py_make_float_array
        print(py_make_float_array(5.5))
    """
    single = [cst]
    return numpy.array(single, dtype=numpy.float32)

30 

31 

def py_pow(x, p, op_version=None):
    """
    Functional form of the python operator ``**``.

    @param      x           base (float or array)
    @param      p           exponent
    @param      op_version  ignored
    @return                 :math:`x^p`
    """
    # the two-argument builtin pow is exactly ``x ** p``
    return pow(x, p)

42 

43 

def py_mul(*x, op_version=None):
    """
    Function for python operator ``*`` applied to any number of operands.

    @param      x           operands to multiply, at least one is expected
                            (an empty call raises IndexError)
    @param      op_version  unused
    @return                 product ``x[0] * x[1] * ...``

    The product is accumulated with ``p = p * y`` and not ``p *= y``.
    The in-place form modifies the caller's first operand when it is
    a mutable object such as a :epkg:`numpy` array. It also fails when
    the result type differs from the first operand's type, for example
    an integer array multiplied by a float.
    """
    product = x[0]
    for y in x[1:]:
        product = product * y
    return product

58 

59 

def py_opp(x, op_version=None):
    """
    Functional form of the python unary operator ``-``.

    @param      x           float or array to negate
    @param      op_version  ignored
    @return                 ``-x``
    """
    negated = -x
    return negated

69 

70 

def squareform_pdist(X, metric='sqeuclidean', op_version=None):
    """
    Replacements for `squareform
    <http://scipy.github.io/devdocs/generated/scipy.spatial.distance.squareform.html>`_
    and `pdist
    <http://scipy.github.io/devdocs/generated/scipy.spatial.distance.pdist.html>`_.

    @param      X           observations, one per row
    @param      metric      distance passed to *pdist*
    @param      op_version  ignored
    @return                 square matrix of pairwise distances
    """
    # pdist returns the condensed vector, squareform expands it
    condensed = pdist(X, metric=metric)
    return squareform(condensed)

79 

80 

def get_default_context():
    """
    Returns a default context useful for most of the conversion
    from a function using :epkg:`numpy` into :epkg:`ONNX`.

    The context contains:

    * the helpers ``py_pow``, ``py_make_float_array``, ``py_mul``, ``py_opp``
      mapped to the functions of this module,
    * ``cdist`` and ``squareform_pdist`` mapped to their own names,
    * every allowed :epkg:`numpy` function, registered twice, under
      ``numpy.<name>`` and ``np.<name>``.

    @return     dictionary ``{name: object}``
    """
    context = {'py_pow': py_pow, 'py_make_float_array': py_make_float_array,
               'py_mul': py_mul, 'py_opp': py_opp,
               'cdist': 'cdist', 'squareform_pdist': 'squareform_pdist'}
    # Every literal but the last ends with a space. Adjacent string
    # literals are concatenated, so without the separator 'divide' + 'equal'
    # and 'mod' + 'multiply' fuse into unknown words and four functions
    # silently disappear from the context.
    allow = set(('abs add ceil arccos arccosh arcsin arcsinh arctan arctanh cos cosh divide '
                 'equal exp floor greater invert less log matmul maximum minimum mod '
                 'multiply power sign sin sinh sqrt square subtract tan tanh transpose').split())
    for k, v in numpy.__dict__.items():
        if k not in allow:
            continue
        context['numpy.%s' % k] = v
        context['np.%s' % k] = v
    return context

98 

99 

def get_default_context_cpl():
    """
    Returns a default useful context to compile the converter
    returned by @see fn translate_fct2onnx.

    The context holds the ``py_*`` helpers of this module, :epkg:`numpy`,
    ``onnx_squareform_pdist`` and ``onnx_cdist`` when the installed
    :epkg:`skl2onnx` provides them, and every name starting with ``Onnx``
    in ``skl2onnx.algebra.onnx_ops`` which is a subclass of ``OnnxOperator``.

    @return     dictionary ``{name: object}``
    @raises     RuntimeError if ``onnx_ops`` exposes a name starting with
                ``Onnx`` which is neither a class nor a function
    """
    ctx = dict(py_make_float_array=py_make_float_array,
               py_pow=py_pow, py_mul=py_mul, py_opp=py_opp,
               numpy=numpy)
    try:
        from skl2onnx.algebra.complex_functions import (
            onnx_squareform_pdist, onnx_cdist)
    except ImportError:  # pragma: no cover
        # Too old version for skl2onnx.
        pass
    else:
        ctx['onnx_squareform_pdist'] = onnx_squareform_pdist
        ctx['onnx_cdist'] = onnx_cdist

    from skl2onnx.algebra import onnx_ops
    from skl2onnx.algebra.onnx_operator import OnnxOperator
    for name, obj in onnx_ops.__dict__.items():
        if not name.startswith("Onnx"):
            continue
        try:
            is_operator = issubclass(obj, OnnxOperator)
        except TypeError as e:
            # issubclass rejects non-classes: plain functions are expected
            # and skipped, anything else is unexpected.
            if inspect.isfunction(obj):
                continue
            raise RuntimeError(  # pragma: no cover
                "Issue with {}={} (type={})".format(name, obj, type(obj))) from e
        if is_operator:
            ctx[name] = obj
    return ctx

130 

131 

def translate_fct2onnx(fct, context=None, cpl=False,
                       context_cpl=None, output_names=None,
                       dtype=numpy.float32,
                       verbose=0, fLOG=None):
    """
    Translates a function into :epkg:`ONNX`. The code it produces
    is using classes *OnnxAbs*, *OnnxAdd*, ...

    @param      fct             function to convert or its source code
                                as a string
    @param      context         context of the function to convert
                                something like ``{'numpy.transpose': numpy.transpose}``,
                                if *context* is None, it receives a default value
                                returned by @see fn get_default_context
    @param      cpl             compile the function after it was
                                created
    @param      context_cpl     context used at compiling time
                                if *context_cpl* is None, it receives a default value
                                returned by @see fn get_default_context_cpl
    @param      output_names    names of the output in the :epkg:`ONNX` graph
    @param      dtype           :epkg:`numpy` float type used to produce the model
                                (not read by this function)
    @param      verbose         integer, display more information
    @param      fLOG            logging function
    @return                     code or compiled code
    @raises     TypeError       if *fct* is neither a string nor a callable
    @raises     ValueError      if *cpl* is True, *fct* is a string and
                                it defines no function

    .. exref::
        :title: Convert a function into ONNX code

        The following code parses a python function and returns
        another python function which produces an :epkg:`ONNX`
        graph if executed.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning
            :process:
            :store_in_file: fct2onnx2.py

            import numpy
            from mlprodict.onnx_grammar import translate_fct2onnx

            def trs(x, y):
                z = x + numpy.transpose(y, axes=[1, 0])
                return x * z

            onnx_code = translate_fct2onnx(
                trs, context={'numpy.transpose': numpy.transpose})
            print(onnx_code)

    The next example goes further and compiles the outcome.

    .. exref::
        :title: Convert a function into ONNX code and run

        The following code parses a python function and returns
        another python function which produces an :epkg:`ONNX`
        graph if executed. The example executes the function,
        creates an :epkg:`ONNX` graph, then uses @see cl OnnxInference
        to compute *predictions*. Finally it compares
        them to the original.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning
            :process:
            :store_in_file: fct2onnx3.py

            import numpy
            from mlprodict.onnx_grammar import translate_fct2onnx
            from mlprodict.onnxrt import OnnxInference
            from skl2onnx.algebra.onnx_ops import (
                OnnxAdd, OnnxTranspose, OnnxMul, OnnxIdentity
            )

            ctx = {'OnnxAdd': OnnxAdd,
                   'OnnxTranspose': OnnxTranspose,
                   'OnnxMul': OnnxMul,
                   'OnnxIdentity': OnnxIdentity}

            def trs(x, y):
                z = x + numpy.transpose(y, axes=[1, 0])
                return x * z

            inputs = {'x': numpy.array([[1, 2]], dtype=numpy.float32),
                      'y': numpy.array([[-0.3, 0.4]], dtype=numpy.float32).T}

            original = trs(inputs['x'], inputs['y'])

            print('original output:', original)

            onnx_fct = translate_fct2onnx(
                trs, context={'numpy.transpose': numpy.transpose},
                cpl=True, context_cpl=ctx, output_names=['Z'])

            onnx_code = onnx_fct('x', 'y', opset_version=12)
            print('ONNX code:', onnx_code)

            onnx_g = onnx_code.to_onnx(inputs, target_opset=12)

            oinf = OnnxInference(onnx_g)
            res = oinf.run(inputs)

            print("ONNX inference:", res['Z'])
            print("ONNX graph:", onnx_g)

    The function to be converted may include python functions
    which must not be converted. In that case, their names
    must be prefixed by ``py_``. The execution of the function
    this one builds may produce the following error::

        TypeError: Parameter to MergeFrom() must be instance of same class:
        expected onnx.TensorProto got onnx.AttributeProto.

    It indicates that constants in the code mix multiple types,
    usually floats and tensors of floats. Floats should be converted
    using the following function::

        def py_make_float_array(cst, op_version=None):
            return numpy.array([cst], dtype=numpy.float32)

    The function replaces empty contexts by default values which
    cover many :epkg:`numpy` functions. The tutorial
    :ref:`l-onnx-tutorial` gives an example of how it can be used
    on a more complex function.
    """
    def compile_code(name, code, context=None):
        """
        Compiles a python function with the given
        context.

        @param      name        function name
        @param      code        python code
        @param      context     context used at compilation
        @return                 compiled function
        """
        if context is None:
            context = {}  # pragma: no cover
        try:
            obj = compile(code, "", "exec")
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile\n{}".format(code)) from e
        # Copies so that executing the code never alters the caller's context;
        # the defined function lands in the local namespace.
        context_g = context.copy()
        context_l = context.copy()
        exec(obj, context_g, context_l)  # pylint: disable=W0122
        return context_l[name]

    if isinstance(fct, str):
        code = fct
    elif callable(fct):
        code = inspect.getsource(fct)
    else:
        raise TypeError(  # pragma: no cover
            "Unable to guess code from type {}.".format(type(fct)))
    node = ast.parse(dedent(code))
    v = CodeNodeVisitor()
    v.visit(node)
    if context is None:
        context = get_default_context()
    onnx_code = v.export(context=context,
                         output_names=output_names)
    if not cpl:
        return onnx_code
    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG('[translate_fct2onnx] python code')
        fLOG(code)
        fLOG('[translate_fct2onnx] ONNX code')
        fLOG(onnx_code)
    if context_cpl is None:
        context_cpl = get_default_context_cpl()
    if 'numpy' not in context_cpl:
        context_cpl = context_cpl.copy()
        context_cpl['numpy'] = numpy

    if isinstance(fct, str):
        # A string has no __name__: use the first function it defines.
        # NOTE(review): assumes the exported code keeps the name of the
        # parsed function, as it does for callables (fct.__name__).
        names = [stmt.name for stmt in node.body
                 if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))]
        if not names:
            raise ValueError(
                "Unable to find a function definition in\n{}".format(code))
        fct_name = names[0]
    else:
        fct_name = fct.__name__
    return compile_code(fct_name, onnx_code, context_cpl)