Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief One class which visits a syntax tree.
4"""
5import pprint
6import numpy
9class CodeTranslator:
10 """
11 Class which converts a Python function into
12 something else. It must implements
13 methods *visit* and *depart*.
14 """
16 def __init__(self, visitor):
17 """
18 @param visitor @see cl CodeNodeVisitor
19 """
20 self._visitor = visitor
22 def export(self, context=None, **kwargs):
23 """
24 Exports the parsed :epkg:`python` code
25 into something.
26 """
27 raise NotImplementedError( # pragma: no cover
28 "This function should be overwritten.")
30 def visit(self, node, info):
31 """
32 Visits a node.
34 @param node visited node
35 @param info info extracted by the visitor
36 """
37 raise NotImplementedError( # pragma: no cover
38 "This function should be overwritten.")
40 def depart(self, node, info):
41 """
42 Leaves a node.
44 @param node visited node
45 @param info info extracted by the visitor
46 """
47 raise NotImplementedError( # pragma: no cover
48 "This function should be overwritten.")
class OnnxTranslator(CodeTranslator):
    """
    Class which converts a Python function into
    an :epkg:`ONNX` function. It must implement
    methods *visit* and *depart*.

    Both methods maintain ``self._stack``, a list of ``(kind, dict)``
    pairs describing the nodes currently being built. When the
    outermost ``FunctionDef`` is left, the collected tree is stored
    in ``self._code_fct`` and :meth:`export` turns it into Python code
    which creates the :epkg:`ONNX` graph.
    """
    # Python AST binary operator name -> ONNX operator type.
    _binary_operators = {
        'Add': 'Add', 'Div': 'Div',
        'Mult': 'Mul', 'Sub': 'Sub',
        'Pow': 'Pow', 'MatMult': 'MatMul',
    }

    # Python AST unary operator name -> ONNX operator type.
    _unary_operators = {
        'Sub': 'Neg',
    }

    # numpy function name -> ONNX operator type. A value starting with
    # a lower case letter is not an ONNX operator: the generated code
    # calls a helper ``onnx_<value>`` (through a ``_onnx_<value>`` wrapper).
    _numpy2onnx_op = {
        'absolute': 'Abs',
        'cos': 'Cos',
        'exp': 'Exp',
        'power': 'Pow',
        'transpose': 'Transpose',
        'sin': 'Sin',
        # complex function
        'inner': 'inner',
    }

    # ONNX operator type -> {numpy keyword name: ONNX attribute name}.
    _parameter_mapping = {
        'Transpose': {'axes': 'perm'}
    }

    class Parameter:
        """
        Holds parameter information.
        """

        def __init__(self, name, value=('#NODEFAULT#', ), annotation=None):
            """
            @param      name        parameter name
            @param      value       parameter value, ``('#NODEFAULT#', )``
                                    means there is no value
            @param      annotation  parameter annotation if any
            """
            self.name = name
            self.value = value
            self.annotation = annotation

        @staticmethod
        def format_value(value):
            """
            Returns a formatted value in python code.

            Strings are double quoted (backslashes, double quotes and
            line breaks escaped), lists and tuples are formatted
            recursively, the marker ``('#NODEFAULT#', )`` returns None,
            anything else goes through :func:`str`.
            """
            if isinstance(value, str):
                # Backslashes must be escaped first, otherwise the
                # backslash added in front of a quote would be doubled
                # as well and the generated literal would be broken.
                escaped = (value.replace('\\', '\\\\')
                           .replace('"', '\\"')
                           .replace('\n', '\\n')
                           .replace('\r', '\\r'))
                return '"{}"'.format(escaped)
            if isinstance(value, list):
                return "[{}]".format(", ".join(map(OnnxTranslator.Parameter.format_value, value)))
            if isinstance(value, tuple):
                if value == ('#NODEFAULT#', ):
                    return None
                return "({})".format(", ".join(map(OnnxTranslator.Parameter.format_value, value)))
            return str(value)

        @property
        def formatted_value(self):
            """
            Returns a formatted value in python code.
            """
            return OnnxTranslator.Parameter.format_value(self.value)

        def __str__(self):
            """
            Into python syntax: ``name`` or ``name=value``.
            """
            rows = [self.name]
            if self.value != ('#NODEFAULT#', ):
                rows.append('=')
                rows.append(self.formatted_value)
            return ''.join(rows)

    def __init__(self, visitor):
        """
        @param      visitor     @see cl CodeNodeVisitor
        """
        CodeTranslator.__init__(self, visitor)
        self._stack = []
        self._code_fct = None

    def _is_stacked(self, name):
        """
        Tells if an element of kind *name* is anywhere in the stack.
        """
        return any(line[0] == name for line in self._stack)

    def _get_last(self, name, info=None):
        """
        Returns the last element of the stack and checks its kind.

        @param      name    expected kind (str) or accepted kinds (tuple)
        @param      info    optional information added to the error message
        @return             last ``(kind, dict)`` pair of the stack
        @raise      RuntimeError if the stack is empty or the last
                    element has an unexpected kind
        """
        if len(self._stack) == 0:
            raise RuntimeError("Stack is empty.")  # pragma: no cover
        last = self._stack[-1]
        if ((isinstance(name, str) and last[0] != name) or
                (isinstance(name, tuple) and last[0] not in name)):
            raise RuntimeError(  # pragma: no cover
                "Last item is not '{}'\n{}\n---\n{}".format(
                    name, pprint.pformat(self._stack),
                    pprint.pformat(info) if info else ""))
        return last

    def make_msg(self, info):
        """
        Make a message with line and column information.

        @param      info    a dictionary (with key ``'node'`` or keys
                            ``'lineno'``, ``'col_offset'``) or an object
                            with attributes ``lineno``, ``col_offset``
        @return             ``"line <lineno>, col <col_offset>"``,
                            ``'?'`` replaces missing information
        """
        lineno = '?'
        col_offset = '?'
        if isinstance(info, dict):
            if 'node' in info:
                node = info['node']
                lineno = node.lineno
                col_offset = node.col_offset
            else:
                lineno = info.get('lineno', lineno)
                col_offset = info.get('col_offset', col_offset)
        else:
            lineno = getattr(info, 'lineno', lineno)
            col_offset = getattr(info, 'col_offset', col_offset)
        return "line {}, col {}".format(lineno, col_offset)

    def export(self, context=None, format='code',  # pylint: disable=W0221
               output_names=None):
        """
        Returns an :epkg:`ONNX` graph or a piece
        of code which could generate the graph.

        @param      context         function used in the function code
        @param      format          ``'code'``
        @param      output_names    add code in the final function
                                    to overwrite the names of the
                                    outputs in the :epkg:`ONNX` graph
        @return                     string or :epkg:`onnx` graph

        This method is used in function @see fn translate_fct2onnx.
        An example of code can be found there.
        """
        if self._code_fct is None:
            raise RuntimeError(  # pragma: no cover
                "No python code was parsed.")
        if context is None:
            context = {}

        def find_onnx_correspondance(fct, info):
            # Returns the ONNX operator type (or helper suffix) matching
            # a numpy function, the function itself for ``py_*`` functions,
            # or *fct* when it is already a string.
            if isinstance(fct, numpy.ufunc):
                name = fct.__name__
            elif callable(fct) and getattr(fct, '__module__', '') in (
                    'numpy', 'numpy.core.fromnumeric'):
                name = fct.__name__
            elif callable(fct) and getattr(fct, '__name__', '').startswith("py_"):
                return fct
            else:
                name = None
            if name is not None and name not in OnnxTranslator._numpy2onnx_op:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find a correspondance to '{}' at {} in \n{}".format(
                        name, self.make_msg(info),
                        "\n".join(sorted(OnnxTranslator._numpy2onnx_op))))
            if name is not None:
                return OnnxTranslator._numpy2onnx_op[name]
            if isinstance(fct, str):
                return fct
            raise RuntimeError(  # pragma: no cover
                "Unable to find a correspondance for function name '{}' in module '{}', "
                "'{}' (type {}) at {}.".format(
                    name, getattr(fct, '__module__', ''),
                    fct, type(fct), self.make_msg(info)))

        def write_expression(stack_fct_used, expr, indent, parameter_mapping=None):
            # Returns the list of code lines producing *expr*, indented
            # by *indent* levels of 4 spaces. Names of the ``onnx_*``
            # helpers used are appended to *stack_fct_used*.
            if isinstance(expr, str):
                # an argument
                return ['{}{}'.format(" " * indent * 4, expr)]
            if isinstance(expr, (int, float)):
                # an argument
                return ['{}{}'.format(" " * indent * 4, expr)]
            if isinstance(expr, OnnxTranslator.Parameter):
                if parameter_mapping is None:
                    name = expr.name
                else:
                    name = parameter_mapping.get(expr.name, expr.name)
                return ["{}{}={}".format(" " * indent * 4, name,
                                         expr.formatted_value)]
            rows = []
            if isinstance(expr, tuple):
                expr = [expr]
            for op, args in expr:
                if op in ('BinOp', 'UnaryOp'):
                    table = (OnnxTranslator._binary_operators if op == 'BinOp'
                             else OnnxTranslator._unary_operators)
                    onnx_name = table[args["op"]]
                    rows.append(
                        '{}Onnx{}('.format(" " * indent * 4, onnx_name))
                    for expr2 in args["args"]:
                        sexpr2 = write_expression(
                            stack_fct_used, expr2, indent + 1)
                        if any('op_version="op_version"' in s for s in sexpr2):
                            continue  # pragma: no cover
                        rows.extend(sexpr2)
                        rows[-1] += ","
                    rows.append('{}op_version=op_version'.format(
                        " " * (indent + 1) * 4))
                    rows.append('{})'.format(" " * indent * 4))
                elif op == 'Call':
                    name = args['name']
                    if name.startswith("onnx_"):
                        raise RuntimeError("The code must not use a function prefixed by 'onnx_' (%s). "
                                           "It indicates that function manipulate ONNX node and "
                                           "the fonction to convert must only deal with arrays." % name)
                    if name not in context:
                        raise RuntimeError(
                            "Unable to find function '{}' at {} in context\n{}\n--\n{}".format(
                                name, self.make_msg(args),
                                '\n'.join(sorted(context)),
                                pprint.pformat(args)))
                    op_conv = find_onnx_correspondance(context[name], args)
                    if callable(op_conv) and op_conv.__name__.startswith('py_'):
                        rows.append(
                            '{}{}('.format(" " * indent * 4, op_conv.__name__))
                    elif callable(op_conv) and op_conv.__name__.startswith('onnx_'):
                        # Calls the ``_onnx_*`` wrapper declared in the
                        # header, like the string branch below does (the
                        # function object itself must not be formatted).
                        stack_fct_used.append(op_conv.__name__)
                        rows.append(
                            '{}_{}('.format(" " * indent * 4, op_conv.__name__))
                    else:
                        prefix = "onnx_" if 'a' <= op_conv[0] <= 'z' else 'Onnx'
                        if prefix == "onnx_":
                            stack_fct_used.append(
                                "{}{}".format(prefix, op_conv))
                            prefix = '_' + prefix
                        rows.append(
                            '{}{}{}('.format(" " * indent * 4, prefix, op_conv))

                    # the first argument is the function name itself
                    opon = args["args"][1:]
                    for expr2 in opon:
                        sexpr2 = write_expression(
                            stack_fct_used, expr2, indent + 1,
                            OnnxTranslator._parameter_mapping.get(op_conv, None))
                        if any('op_version="op_version"' in s for s in sexpr2):
                            continue
                        rows.extend(sexpr2)
                        rows[-1] += ","
                    rows.append('{}op_version=op_version'.format(
                        " " * (indent + 1) * 4))
                    rows.append('{})'.format(" " * indent * 4))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to interpret '{}'.".format(expr))
            return rows

        def write_function(stack_fct_used, to_replaces, node):
            # Returns the code lines of the whole function. A placeholder
            # line is inserted right after the signature and registered
            # in *to_replaces*.
            rows = []
            name, args = node
            if name != 'FunctionDef':
                raise RuntimeError(  # pragma: no cover
                    "The code being translated should be a single function not "
                    "'{}' at {}.".format(name, self.make_msg(args)))
            list_args = list(map(str, args['args']))
            if all(map(lambda s: 'dtype=' not in s, list_args)):
                list_args.append("dtype=numpy.float32")
            if all(map(lambda s: 'op_version=' not in s, list_args)):
                list_args.append("op_version=None")
            fct_name = args['name']
            rows.append("def {}({}):".format(
                fct_name, ', '.join(list_args)))
            indent = 1

            to_replace = "# __HEADER__{}".format(id(node))
            to_replaces.append(to_replace)
            rows.append("{}{}".format(" " * (indent * 4), to_replace))

            code = args['code']
            for op, args in code:
                if op == "Assign":
                    name = args['name']
                    args = args["args"]
                    rows.append("{}{} = (".format(" " * (indent * 4), name))
                    rows.extend(write_expression(
                        stack_fct_used, args, indent + 1))
                    rows.append("{})".format(" " * (indent * 4)))
                elif op == "Return":
                    args = args["code"]
                    if output_names is None:
                        rows.append("{}return (".format(" " * (indent * 4)))
                        rows.extend(write_expression(
                            stack_fct_used, args, indent + 1))
                        rows.append("{})".format(" " * (indent * 4)))
                    else:
                        rows.append(
                            "{}return OnnxIdentity(".format(" " * (indent * 4)))
                        subrows = write_expression(
                            stack_fct_used, args, indent + 1)
                        subrows[-1] += ","
                        rows.extend(subrows)
                        rows.append("{}output_names={},".format(
                            " " * ((indent + 1) * 4), str(output_names)))
                        rows.append("{}op_version=op_version".format(
                            " " * ((indent + 1) * 4)))
                        rows.append("{})".format(" " * (indent * 4)))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to process operator '{}' at {}. "
                        "Make sure it is either an affectation, "
                        "either a return.".format(op, self.make_msg(args)))
            return rows

        stack_fct_used = []
        to_replaces = []
        rows = write_function(stack_fct_used, to_replaces, self._code_fct)

        # handling dtype parameter
        if len(to_replaces) != 1:
            raise RuntimeError(  # pragma: no cover
                "The following code misses a placeholder:\n{}".format(
                    "\n".join(rows)))
        index = -1
        for i, row in enumerate(rows):
            if to_replaces[0] in row:
                index = i
                break

        # The placeholder is replaced by one wrapper per onnx_* helper,
        # the wrapper forwards dtype and op_version of the generated function.
        header = []
        for fct in stack_fct_used:
            header.append(
                "    _{0} = lambda *args, op_version=op_version, **kwargs: {0}(*args, dtype=dtype, "
                "op_version=op_version, **kwargs)".format(fct))
        if len(header) > 0:
            header.append('')
        rows[index:index + 1] = header

        return "\n".join(rows)

    def visit(self, node, info):
        """
        Visits a node: pushes a new element on the stack for nodes
        which hold children and checks the parent of the other ones.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        if 'type' not in info:
            return

        kind = info['type']
        if kind == "Module":
            return
        if kind == "FunctionDef":
            if self._is_stacked('FunctionDef'):
                raise RuntimeError("Nested functions are not allowed at {}.".format(
                    self.make_msg(node)))
            self._stack.append(
                ('FunctionDef', {'args': [], 'code': [], 'name': info['name'], 'default': [],
                                 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == "arguments":
            self._get_last('FunctionDef')
            return
        if kind == "arg":
            return
        if kind == "Assign":
            self._stack.append(
                ('Assign', {'args': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == 'Name':
            self._get_last(
                ('Assign', 'BinOp', 'Call', 'Return', 'FunctionDef', 'keyword', 'UnaryOp'))
            return
        if kind == 'BinOp':
            self._stack.append(
                ('BinOp', {'args': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == 'UnaryOp':
            self._stack.append(
                ('UnaryOp', {'args': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind in OnnxTranslator._binary_operators:
            _, buf = self._get_last(('BinOp', 'UnaryOp'))
            buf['op'] = kind
            return
        if kind == 'Call':
            self._stack.append(
                ('Call', {'name': info['str'], 'args': [], 'lineno': node.lineno,
                          'col_offset': node.col_offset}))
            return
        if kind == 'Return':
            self._get_last('FunctionDef')
            self._stack.append(
                ('Return', {'code': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == "Attribute":
            if info.get('str', '') == 'T':
                raise NotImplementedError(  # pragma: no cover
                    "Transpose should be done with numpy.transpose not with .T'{}' "
                    "at {}\n{}\n---\n{}".format(
                        info.get('type', '?'), self.make_msg(node),
                        pprint.pformat(info), pprint.pformat(self._stack)))
            self._get_last('Call')
            return
        if kind == 'keyword':
            self._get_last('Call')
            self._stack.append(
                ('keyword', {'name': "{0}".format(node.arg),
                             'lineno': getattr(node, 'lineno', '?'),
                             'col_offset': getattr(node, 'col_offset', '?')}))
            return
        if kind == 'List':
            self._get_last('keyword')
            self._stack.append(
                ('List', {'elts': [], 'lineno': getattr(node, 'lineno', '?'),
                          'col_offset': getattr(node, 'col_offset', '?')}))
            return
        if kind in ('Num', 'Cst'):
            # Same parents as accepted by depart, a constant can also
            # be the value of a keyword (``axis=1``).
            self._get_last(('List', 'UnaryOp', 'BinOp', 'FunctionDef', 'Call', 'keyword'))
            return
        if kind == 'Str':
            self._get_last('keyword')
            return

        raise NotImplementedError(  # pragma: no cover
            "Unable to interpret kind '{}' at {}\n{}\n---\n{}".format(
                info.get('type', '?'), self.make_msg(
                    node), pprint.pformat(info),
                pprint.pformat(self._stack)))

    def _fix_default_values(self, code_fct):
        """
        Maps default values with parameter names.

        Default values only apply to the last parameters: the *i*-th
        parameter gets default ``nbdef - (nbpar - i)`` if that index is
        not negative. ``code_fct[1]['args']`` is replaced by a list
        of @see cl OnnxTranslator.Parameter.
        """
        defaults = code_fct[1]['default']
        params = code_fct[1]['args']
        nbdef = len(defaults)
        nbpar = len(params)
        args = []
        for i, (name, annotation) in enumerate(params):
            j = nbdef - (nbpar - i)
            if j >= 0:
                p = OnnxTranslator.Parameter(
                    name, annotation=annotation, value=defaults[j])
            else:
                p = OnnxTranslator.Parameter(name, annotation=annotation)
            args.append(p)
        code_fct[1]['args'] = args

    def _post_process(self, op, node):
        """
        Simplifies some operator such as ``OnnxNeg(2)``.

        Only acts when *op* is None: every argument of *node* which is
        a unary ``Sub`` applied to a single number is replaced by the
        negated number.
        """
        if op is None and 'args' in node:
            for i, arg in enumerate(node['args']):
                if not isinstance(arg, tuple):
                    continue
                o, v = arg
                if (o == 'UnaryOp' and len(v['args']) == 1 and
                        isinstance(v['args'][0], (int, float, numpy.int64,
                                                  numpy.float32, numpy.float64))):
                    if v['op'] == 'Sub':
                        node['args'][i] = -v['args'][0]

    def depart(self, node, info):
        """
        Leaves a node: stores the information collected for it
        into its parent and pops the stack when needed.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        if 'type' not in info:
            return

        kind = info['type']
        if kind == "arg":
            return
        if kind == "arguments":
            _, buf = self._get_last('FunctionDef')
            for child in info['children']:
                if child['type'] == 'Str':
                    buf['default'].append(child['str'])
                elif child['type'] in ('Num', 'Cst'):
                    buf['default'].append(child['n'])
                elif child['type'] == 'arg':
                    buf['args'].append(
                        (child['str'], child.get('annotation', None)))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to interpret type '{}' in function definition."
                        "\n{}".format(
                            child['type'], pprint.pformat(info)))
            return

        if kind == "Name":
            op, buf = self._get_last(
                ('Assign', 'BinOp', 'Call', 'Return', 'FunctionDef', 'keyword',
                 'UnaryOp'),
                info)
            if op == 'Assign':
                buf['name'] = info['str']
                return
            elif op in ('BinOp', 'Call'):
                buf['args'].append(info['str'])
                return
            elif op == 'Return':
                buf['code'] = info['str']
                return
            elif op == 'keyword':
                buf['value'] = info['str']
                return
            elif op == 'UnaryOp':
                buf['args'].append(info['str'])
                return
            elif op == 'FunctionDef':
                raise RuntimeError("Default value must be constant, variable '{}' was "
                                   "detected.".format(info['str']))

        if kind in OnnxTranslator._binary_operators:
            self._get_last(('BinOp', 'UnaryOp'))
            return
        if kind in ('Call', 'BinOp', 'Assign', 'Return', 'UnaryOp'):
            op, buf = self._get_last(
                ('Call', 'BinOp', 'Assign', 'Return', 'UnaryOp'))
            self._post_process(op, buf)
            self._stack.pop()
            opp, parent = self._get_last(
                ('Call', 'BinOp', 'Assign', 'FunctionDef', 'Return', 'UnaryOp'))
            if opp in ('FunctionDef', 'Return'):
                parent['code'].append((op, buf))
            else:
                parent['args'].append((op, buf))
            self._post_process(None, parent)
            return
        if kind == 'FunctionDef':
            if len(self._stack) == 1:
                self._code_fct = self._stack[-1]
                self._fix_default_values(self._code_fct)
                self._stack = []
            return
        if kind == 'Module':
            return
        if kind == 'Attribute':
            _, buf = self._get_last(('Call', 'BinOp'))

            # ``module.function``: the first child is the module name
            if len(info["children"]) > 0:
                fir = info["children"][0]
                if fir["type"] == "Name":
                    parent = fir["node"].id
                    info["str"] = "{0}.{1}".format(parent, info["str"])
                    info["children"][0]["remove"] = True

            buf['name'] = info["str"]
            buf['args'][0] = info["str"]
            return
        if kind in ('Num', 'Cst'):
            op, buf = self._get_last(
                ('List', 'BinOp', 'UnaryOp', 'FunctionDef', 'Call', 'keyword'))
            if op == 'FunctionDef':
                # default values are collected when 'arguments' is left
                return
            if op == 'List':
                buf['elts'].append(info['n'])
            elif op == 'keyword':
                buf['value'] = info['n']
            else:
                buf['args'].append(info['n'])
            return
        if kind == 'Str':
            _, buf = self._get_last('keyword')
            buf['value'] = info['str']
            return
        if kind == 'List':
            op, buf = self._get_last('List')
            value = buf['elts']
            self._post_process(op, buf)
            self._stack.pop()
            _, parent = self._get_last('keyword')
            parent['value'] = value
            self._post_process(None, parent)
            return
        if kind == 'keyword':
            op, buf = self._get_last('keyword')
            name = buf["name"]
            if 'value' not in buf:
                raise RuntimeError(str(buf))  # pragma: no cover
            value = buf['value']
            self._post_process(op, buf)
            self._stack.pop()
            _, parent = self._get_last('Call')
            parent['args'].append(OnnxTranslator.Parameter(name, value))
            self._post_process(None, parent)
            return

        raise NotImplementedError(  # pragma: no cover
            "Unable to interpret kind '{}' at {}\n{}\n---\n{}".format(
                info.get('type', '?'), self.make_msg(
                    node), pprint.pformat(info),
                pprint.pformat(self._stack)))