Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief One class which visits a syntax tree.
4"""
5import inspect
6import ast
7from textwrap import dedent
8import numpy
9from scipy.spatial.distance import squareform, pdist
10from .node_visitor_translator import CodeNodeVisitor
def py_make_float_array(cst, op_version=None):
    """
    Wraps a constant into a :epkg:`numpy` array of type float32
    holding that single value.

    @param      cst             constant to wrap
    @param      op_version      ignored, only present so that every
                                ``py_*`` helper shares the same signature
    @return                     float32 array, of shape ``(1,)`` when
                                *cst* is a scalar

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnx_grammar.onnx_translation import py_make_float_array
        print(py_make_float_array(5.5))
    """
    wrapped = [cst]
    return numpy.array(wrapped, dtype=numpy.float32)
def py_pow(x, p, op_version=None):
    """
    Implements the python operator ``**``.

    @param      x               base (float or array)
    @param      p               exponent
    @param      op_version      ignored
    @return                     :math:`x^p`
    """
    # the two-argument builtin ``pow`` dispatches exactly like ``x ** p``
    return pow(x, p)
def py_mul(*x, op_version=None):
    """
    Implements the python operator ``*`` for one or more operands.

    @param      x               operands (floats or arrays), at least one
    @param      op_version      ignored
    @return                     product ``x[0] * x[1] * ...``
    @raises     IndexError      when called without any operand
    """
    if not x:
        raise IndexError("py_mul expects at least one operand.")
    product = x[0]
    for factor in x[1:]:
        # Out-of-place product on purpose: ``product *= factor`` would
        # overwrite the caller's first array in place and fails when an
        # integer array is multiplied by a float.
        product = product * factor
    return product
def py_opp(x, op_version=None):
    """
    Implements the python unary operator ``-``.

    @param      x               float or array to negate
    @param      op_version      ignored
    @return                     ``-x``
    """
    negated = -x
    return negated
def squareform_pdist(X, metric='sqeuclidean', op_version=None):
    """
    Computes the square matrix of pairwise distances between the rows
    of *X*, replacing the combination of `squareform
    <http://scipy.github.io/devdocs/generated/scipy.spatial.distance.squareform.html>`_
    and `pdist
    <http://scipy.github.io/devdocs/generated/scipy.spatial.distance.pdist.html>`_.

    @param      X               observations, one per row
    @param      metric          metric name forwarded to *pdist*
    @param      op_version      ignored
    @return                     square, symmetric distance matrix
    """
    condensed = pdist(X, metric=metric)
    return squareform(condensed)
def get_default_context():
    """
    Returns a default context useful for most of the conversions
    from a function using :epkg:`numpy` into :epkg:`ONNX`.

    The context maps the ``py_*`` helpers to themselves and maps
    ``'cdist'`` and ``'squareform_pdist'`` to their own names. It also
    registers every allowed :epkg:`numpy` function under both the
    ``numpy.`` and ``np.`` prefixes.

    @return     dictionary ``{name: object}``
    """
    context = {'py_pow': py_pow, 'py_make_float_array': py_make_float_array,
               'py_mul': py_mul, 'py_opp': py_opp,
               'cdist': 'cdist', 'squareform_pdist': 'squareform_pdist'}
    # Every fragment but the last ends with a space. Adjacent string literals
    # are concatenated, so a missing space glues two names together
    # (e.g. 'divideequal') and silently drops both from the allowed set.
    allow = set(('abs add ceil arccos arccosh arcsin arcsinh arctan arctanh cos cosh divide '
                 'equal exp floor greater invert less log matmul maximum minimum mod '
                 'multiply power sign sin sinh sqrt square subtract tan tanh transpose').split())
    # Iterating numpy's namespace keeps only names this numpy version defines.
    for k, v in numpy.__dict__.items():
        if k not in allow:
            continue
        context['numpy.%s' % k] = v
        context['np.%s' % k] = v
    return context
def get_default_context_cpl():
    """
    Builds the context used to compile the converter returned by
    @see fn translate_fct2onnx.

    The context contains:

    * the ``py_*`` helpers and :epkg:`numpy`,
    * ``onnx_squareform_pdist`` and ``onnx_cdist``, when the installed
      :epkg:`skl2onnx` provides them,
    * every entry of ``skl2onnx.algebra.onnx_ops`` whose name starts
      with ``Onnx`` and which is a subclass of ``OnnxOperator``.

    @return     dictionary ``{name: object}``
    @raises     RuntimeError    if an ``Onnx*`` entry of ``onnx_ops`` is
                                neither a class nor a function
    """
    ctx = dict(py_make_float_array=py_make_float_array,
               py_pow=py_pow, py_mul=py_mul, py_opp=py_opp,
               numpy=numpy)

    try:
        from skl2onnx.algebra.complex_functions import (
            onnx_squareform_pdist, onnx_cdist)
    except ImportError:  # pragma: no cover
        # Too old version for skl2onnx.
        pass
    else:
        ctx['onnx_squareform_pdist'] = onnx_squareform_pdist
        ctx['onnx_cdist'] = onnx_cdist

    from skl2onnx.algebra import onnx_ops
    from skl2onnx.algebra.onnx_operator import OnnxOperator

    for name, obj in vars(onnx_ops).items():
        if not name.startswith("Onnx"):
            continue
        try:
            is_operator = issubclass(obj, OnnxOperator)
        except TypeError as e:
            # issubclass rejects anything that is not a class: functions
            # are expected in onnx_ops and skipped, anything else is not.
            if inspect.isfunction(obj):
                continue
            raise RuntimeError(  # pragma: no cover
                "Issue with {}={} (type={})".format(name, obj, type(obj))) from e
        if is_operator:
            ctx[name] = obj
    return ctx
def _guess_function_name(fct, tree):
    """
    Returns the name of the function @see fn translate_fct2onnx
    must extract from the compiled converter.

    @param      fct             object given to @see fn translate_fct2onnx,
                                a callable or a string holding python code
    @param      tree            module node obtained by parsing the code of *fct*
    @return                     ``fct.__name__`` for a callable, otherwise the
                                name of the first top-level function defined
                                in *tree*
    @raises     ValueError      if *fct* is a string which defines no function
    """
    if callable(fct):
        return fct.__name__
    for stmt in tree.body:
        if isinstance(stmt, ast.FunctionDef):
            return stmt.name
    raise ValueError(
        "Unable to find a function definition in\n{}".format(fct))


def translate_fct2onnx(fct, context=None, cpl=False,
                       context_cpl=None, output_names=None,
                       dtype=numpy.float32,
                       verbose=0, fLOG=None):
    """
    Translates a function into :epkg:`ONNX`. The code it produces
    is using classes *OnnxAbs*, *OnnxAdd*, ...

    @param      fct             function to convert, either a callable
                                or a string holding its python code
    @param      context         context of the function to convert
                                something like ``{'numpy.transpose': numpy.transpose}``,
                                if *context* is None, it receives a default value
                                returned by @see fn get_default_context
    @param      cpl             compile the function after it was
                                created
    @param      context_cpl     context used at compiling time
                                if *context_cpl* is None, it receives a default value
                                returned by @see fn get_default_context_cpl
    @param      output_names    names of the output in the :epkg:`ONNX` graph
    @param      dtype           :epkg:`numpy` float type used to produce the model
                                (note: not referenced by this function's body)
    @param      verbose         integer, display more information
                                (only used when *cpl* is True)
    @param      fLOG            logging function
    @return                     code or compiled code

    .. exref::
        :title: Convert a function into ONNX code

        The following code parses a python function and returns
        another python function which produces an :epkg:`ONNX`
        graph if executed.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning
            :process:
            :store_in_file: fct2onnx2.py

            import numpy
            from mlprodict.onnx_grammar import translate_fct2onnx

            def trs(x, y):
                z = x + numpy.transpose(y, axes=[1, 0])
                return x * z

            onnx_code = translate_fct2onnx(
                trs, context={'numpy.transpose': numpy.transpose})
            print(onnx_code)

    The next example goes further and compiles the outcome.

    .. exref::
        :title: Convert a function into ONNX code and run

        The following code parses a python function and returns
        another python function which produces an :epkg:`ONNX`
        graph if executed. The example executes the function,
        creates an :epkg:`ONNX` then uses @see cl OnnxInference
        to compute *predictions*. Finally it compares
        them to the original.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning
            :process:
            :store_in_file: fct2onnx3.py

            import numpy
            from mlprodict.onnx_grammar import translate_fct2onnx
            from mlprodict.onnxrt import OnnxInference
            from skl2onnx.algebra.onnx_ops import (
                OnnxAdd, OnnxTranspose, OnnxMul, OnnxIdentity
            )

            ctx = {'OnnxAdd': OnnxAdd,
                   'OnnxTranspose': OnnxTranspose,
                   'OnnxMul': OnnxMul,
                   'OnnxIdentity': OnnxIdentity}

            def trs(x, y):
                z = x + numpy.transpose(y, axes=[1, 0])
                return x * z

            inputs = {'x': numpy.array([[1, 2]], dtype=numpy.float32),
                      'y': numpy.array([[-0.3, 0.4]], dtype=numpy.float32).T}

            original = trs(inputs['x'], inputs['y'])

            print('original output:', original)

            onnx_fct = translate_fct2onnx(
                trs, context={'numpy.transpose': numpy.transpose},
                cpl=True, context_cpl=ctx, output_names=['Z'])

            onnx_code = onnx_fct('x', 'y', opset_version=12)
            print('ONNX code:', onnx_code)

            onnx_g = onnx_code.to_onnx(inputs, target_opset=12)

            oinf = OnnxInference(onnx_g)
            res = oinf.run(inputs)

            print("ONNX inference:", res['Z'])
            print("ONNX graph:", onnx_g)

    The function to be converted may include python functions
    which must not be converted. In that case, their name
    must be prefixed by ``py_``. Executing the function
    this one builds may produce the following error::

        TypeError: Parameter to MergeFrom() must be instance of same class:
        expected onnx.TensorProto got onnx.AttributeProto.

    It indicates that the constants in the code mix multiple types,
    usually floats and tensors of floats. Floats should be converted
    using the following function::

        def py_make_float_array(cst):
            return numpy.array([cst], dtype=numpy.float32)

    The function replaces empty contexts by default values which
    cover many :epkg:`numpy` functions. The tutorial
    :ref:`l-onnx-tutorial` gives an example of how it can be used
    on a more complex function.
    """
    def compile_code(name, code, context=None):
        """
        Compiles a python function with the given
        context.

        @param      name        function name
        @param      code        python code
        @param      context     context used at compilation
        @return                 compiled function
        """
        if context is None:
            context = {}  # pragma: no cover
        try:
            obj = compile(code, "", "exec")
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile\n{}".format(code)) from e
        context_g = context.copy()
        context_l = context.copy()
        exec(obj, context_g, context_l)  # pylint: disable=W0122
        return context_l[name]

    if isinstance(fct, str):
        code = fct
    elif callable(fct):
        code = inspect.getsource(fct)
    else:
        raise TypeError(  # pragma: no cover
            "Unable to guess code from type {}.".format(type(fct)))
    node = ast.parse(dedent(code))
    v = CodeNodeVisitor()
    v.visit(node)
    if context is None:
        context = get_default_context()
    onnx_code = v.export(context=context,
                         output_names=output_names)
    if not cpl:
        return onnx_code
    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG('[translate_fct2onnx] python code')
        fLOG(code)
        fLOG('[translate_fct2onnx] ONNX code')
        fLOG(onnx_code)
    if context_cpl is None:
        context_cpl = get_default_context_cpl()
    if 'numpy' not in context_cpl:
        context_cpl = context_cpl.copy()
        context_cpl['numpy'] = numpy
    # A string *fct* has no ``__name__``. In that case the name is read from
    # the parsed source instead, which the compiled converter is assumed to
    # reuse (the same assumption ``fct.__name__`` made for callables).
    name = _guess_function_name(fct, node)
    return compile_code(name, onnx_code, context_cpl)