Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief One class which visits a syntax tree. 

4""" 

5import pprint 

6import numpy 

7 

8 

class CodeTranslator:
    """
    Base class for translators turning a Python function
    into something else. A subclass is expected to overwrite
    methods *visit*, *depart* and *export*; the base versions
    only raise :class:`NotImplementedError`.
    """

    def __init__(self, visitor):
        """
        @param      visitor     @see cl CodeNodeVisitor
        """
        self._visitor = visitor

    @staticmethod
    def _raise_not_overwritten():
        # Common failure for every method a subclass must provide.
        raise NotImplementedError(  # pragma: no cover
            "This function should be overwritten.")

    def export(self, context=None, **kwargs):
        """
        Exports the parsed :epkg:`python` code into something.
        Must be overwritten by a subclass.
        """
        CodeTranslator._raise_not_overwritten()

    def visit(self, node, info):
        """
        Called when the visitor enters a node.
        Must be overwritten by a subclass.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        CodeTranslator._raise_not_overwritten()

    def depart(self, node, info):
        """
        Called when the visitor leaves a node.
        Must be overwritten by a subclass.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        CodeTranslator._raise_not_overwritten()

49 

50 

class OnnxTranslator(CodeTranslator):
    """
    Class which converts a Python function into
    an :epkg:`ONNX` function. It implements
    methods *visit* and *depart*.

    *visit* is called when entering a node, *depart* when leaving it.
    Partial results are kept in a stack of tuples ``(kind, data)``.
    Once the function definition is fully parsed, it is stored and
    method *export* turns it into python code building the ONNX graph.
    """

    # Python AST binary operator name -> ONNX operator type.
    _binary_operators = {
        'Add': 'Add', 'Div': 'Div',
        'Mult': 'Mul', 'Sub': 'Sub',
        'Pow': 'Pow', 'MatMult': 'MatMul',
    }

    # Python AST unary operator name -> ONNX operator type.
    _unary_operators = {
        'Sub': 'Neg',
    }

    # numpy function name -> ONNX operator type. A lowercase value is
    # the suffix of an ``onnx_<value>`` helper function (see export).
    _numpy2onnx_op = {
        'absolute': 'Abs',
        'cos': 'Cos',
        'exp': 'Exp',
        'power': 'Pow',
        'transpose': 'Transpose',
        'sin': 'Sin',
        # complex function
        'inner': 'inner',
    }

    # Keyword renaming applied when a call is translated into an ONNX
    # operator (numpy ``axes`` becomes ``perm`` for Transpose).
    _parameter_mapping = {
        'Transpose': {'axes': 'perm'}
    }

    class Parameter:
        """
        Holds parameter information.
        """

        def __init__(self, name, value=('#NODEFAULT#', ), annotation=None):
            """
            @param      name        parameter name
            @param      value       parameter value, ``('#NODEFAULT#', )``
                                    means the parameter has no default value
            @param      annotation  parameter annotation or None
            """
            self.name = name
            self.value = value
            self.annotation = annotation

        @staticmethod
        def format_value(value):
            """
            Returns a formatted value in python code.

            @param      value       str, list, tuple or any value
                                    whose ``str`` is valid python
            @return                 python code (str) or None for the
                                    *no default* marker
            """
            if isinstance(value, str):
                # Backslashes must be escaped first: escaping them after the
                # double quotes would double the backslashes inserted in
                # front of the quotes and produce invalid python code.
                escaped = value.replace('\\', '\\\\').replace('"', '\\"')
                return '"{}"'.format(escaped)
            if isinstance(value, list):
                return "[{}]".format(", ".join(
                    map(OnnxTranslator.Parameter.format_value, value)))
            if isinstance(value, tuple):
                if value == ('#NODEFAULT#', ):
                    return None
                return "({})".format(", ".join(
                    map(OnnxTranslator.Parameter.format_value, value)))
            return str(value)

        @property
        def formatted_value(self):
            """
            Returns a formatted value in python code.
            """
            return OnnxTranslator.Parameter.format_value(self.value)

        def __str__(self):
            """
            Into python syntax: ``name`` or ``name=value``.
            """
            rows = [self.name]
            if self.value != ('#NODEFAULT#', ):
                rows.append('=')
                rows.append(self.formatted_value)
            return ''.join(rows)

    def __init__(self, visitor):
        """
        @param      visitor     @see cl CodeNodeVisitor
        """
        CodeTranslator.__init__(self, visitor)
        # stack of tuples (kind, data) built while visiting the tree
        self._stack = []
        # stack entry of the parsed FunctionDef, set by depart
        self._code_fct = None

    def _is_stacked(self, name):
        """
        Tells if an element of kind *name* is somewhere in the stack.
        """
        return any(line[0] == name for line in self._stack)

    def _get_last(self, name, info=None):
        """
        Returns the last element of the stack after checking its kind.

        @param      name        expected kind (str) or accepted kinds (tuple)
        @param      info        optional information added to the error message
        @return                 tuple ``(kind, data)``
        @raises                 RuntimeError if the stack is empty or the
                                last element has an unexpected kind
        """
        if not self._stack:
            raise RuntimeError("Stack is empty.")  # pragma: no cover
        last = self._stack[-1]
        if ((isinstance(name, str) and last[0] != name) or
                (isinstance(name, tuple) and last[0] not in name)):
            raise RuntimeError(  # pragma: no cover
                "Last item is not '{}'\n{}\n---\n{}".format(
                    name, pprint.pformat(self._stack),
                    pprint.pformat(info) if info else ""))
        return last

    def make_msg(self, info):
        """
        Make a message with line and column information.

        @param      info        a dictionary (with key ``'node'`` or keys
                                ``'lineno'``, ``'col_offset'``) or a node
        @return                 ``"line <lineno>, col <col_offset>"``,
                                ``'?'`` for unknown values
        """
        lineno = '?'
        col_offset = '?'
        if isinstance(info, dict):
            if 'node' in info:
                node = info['node']
                lineno = node.lineno
                col_offset = node.col_offset
            else:
                if 'lineno' in info:
                    lineno = info['lineno']
                if 'col_offset' in info:
                    col_offset = info['col_offset']
        else:
            if hasattr(info, 'lineno'):
                lineno = info.lineno
            if hasattr(info, 'col_offset'):
                col_offset = info.col_offset

        return "line {}, col {}".format(lineno, col_offset)

    def export(self, context=None, format='code',  # pylint: disable=W0221
               output_names=None):
        """
        Returns an :epkg:`ONNX` graph or a piece
        of code which could generate the graph.

        @param      context         function used in the function code
        @param      format          ``'code'`` (the value is not checked,
                                    code is always returned)
        @param      output_names    add code in the final function
                                    to overwrite the names of the
                                    outputs in the :epkg:`ONNX` graph
        @return                     string or :epkg:`onnx` graph

        This method is used in function @see fn translate_fct2onnx.
        An example of code can be found there.
        """
        if self._code_fct is None:
            raise RuntimeError(  # pragma: no cover
                "No python code was parsed.")
        if context is None:
            context = {}

        def find_onnx_correspondance(fct, info):
            # Maps a function found in *context* to an ONNX operator type,
            # a lowercase ``onnx_*`` helper suffix, a string given as is,
            # or the function itself when its name starts with ``py_``.
            if isinstance(fct, numpy.ufunc):
                name = fct.__name__
            elif callable(fct) and getattr(fct, '__module__', '') in (
                    'numpy', 'numpy.core.fromnumeric'):
                name = fct.__name__
            elif callable(fct) and getattr(fct, '__name__', '').startswith("py_"):
                return fct
            else:
                name = None
            if name is not None and name not in OnnxTranslator._numpy2onnx_op:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find a correspondance to '{}' at {} in \n{}".format(
                        name, self.make_msg(info),
                        "\n".join(sorted(OnnxTranslator._numpy2onnx_op))))
            if name is not None:
                return OnnxTranslator._numpy2onnx_op[name]
            if isinstance(fct, str):
                return fct
            raise RuntimeError(  # pragma: no cover
                "Unable to find a correspondance for function name '{}' in module '{}', "
                "'{}' (type {}) at {}.".format(
                    name, getattr(fct, '__module__', ''),
                    fct, type(fct), self.make_msg(info)))

        def write_arguments(rows, stack_fct_used, arguments, indent,
                            parameter_mapping=None):
            # Appends every argument of an operator call (each followed by a
            # comma), then ``op_version=op_version`` and the closing parenthesis.
            for expr2 in arguments:
                sexpr2 = write_expression(
                    stack_fct_used, expr2, indent + 1, parameter_mapping)
                if any('op_version="op_version"' in s for s in sexpr2):
                    continue  # pragma: no cover
                rows.extend(sexpr2)
                rows[-1] += ","
            rows.append('{}op_version=op_version'.format(
                " " * (indent + 1) * 4))
            rows.append('{})'.format(" " * indent * 4))

        def write_expression(stack_fct_used, expr, indent, parameter_mapping=None):
            # Returns the lines of python code building expression *expr*.
            pad = " " * indent * 4
            if isinstance(expr, (str, int, float, numpy.number)):
                # a variable name or a constant; numpy scalars may come
                # from method _post_process
                return ['{}{}'.format(pad, expr)]
            if isinstance(expr, OnnxTranslator.Parameter):
                if parameter_mapping is None:
                    name = expr.name
                else:
                    name = parameter_mapping.get(expr.name, expr.name)
                return ["{}{}={}".format(pad, name, expr.formatted_value)]
            rows = []
            if isinstance(expr, tuple):
                expr = [expr]
            for op, args in expr:
                if op == 'BinOp':
                    onnx_name = OnnxTranslator._binary_operators[args["op"]]
                    rows.append('{}Onnx{}('.format(pad, onnx_name))
                    write_arguments(rows, stack_fct_used, args["args"], indent)
                elif op == 'UnaryOp':
                    onnx_name = OnnxTranslator._unary_operators[args["op"]]
                    rows.append('{}Onnx{}('.format(pad, onnx_name))
                    write_arguments(rows, stack_fct_used, args["args"], indent)
                elif op == 'Call':
                    name = args['name']
                    if name.startswith("onnx_"):
                        raise RuntimeError("The code must not use a function prefixed by 'onnx_' (%s). "
                                           "It indicates that function manipulate ONNX node and "
                                           "the fonction to convert must only deal with arrays." % name)
                    if name not in context:
                        raise RuntimeError(
                            "Unable to find function '{}' at {} in context\n{}\n--\n{}".format(
                                name, self.make_msg(args),
                                '\n'.join(sorted(context)),
                                pprint.pformat(args)))
                    op_conv = find_onnx_correspondance(context[name], args)
                    if callable(op_conv) and op_conv.__name__.startswith('py_'):
                        rows.append('{}{}('.format(pad, op_conv.__name__))
                    elif callable(op_conv) and op_conv.__name__.startswith('onnx_'):
                        # Same convention as the string case below: call the
                        # ``_onnx_*`` wrapper declared in the function header.
                        stack_fct_used.append(op_conv.__name__)
                        rows.append('{}_{}('.format(pad, op_conv.__name__))
                    else:
                        prefix = "onnx_" if 'a' <= op_conv[0] <= 'z' else 'Onnx'
                        if prefix == "onnx_":
                            stack_fct_used.append(
                                "{}{}".format(prefix, op_conv))
                            prefix = '_' + prefix
                        rows.append('{}{}{}('.format(pad, prefix, op_conv))

                    # The first argument is the called function itself
                    # (stored by depart when leaving Name/Attribute nodes).
                    write_arguments(
                        rows, stack_fct_used, args["args"][1:], indent,
                        OnnxTranslator._parameter_mapping.get(op_conv, None))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to interpret '{}'.".format(expr))
            return rows

        def write_function(stack_fct_used, to_replaces, node):
            # Returns the lines of the python function building the graph.
            rows = []
            name, fct_info = node
            if name != 'FunctionDef':
                raise RuntimeError(  # pragma: no cover
                    "The code being translated should be a single function not "
                    "'{}' at {}.".format(name, self.make_msg(fct_info)))
            list_args = list(map(str, fct_info['args']))
            if all('dtype=' not in s for s in list_args):
                list_args.append("dtype=numpy.float32")
            if all('op_version=' not in s for s in list_args):
                list_args.append("op_version=None")
            rows.append("def {}({}):".format(
                fct_info['name'], ', '.join(list_args)))
            indent = 1
            pad = " " * (indent * 4)
            sub_pad = " " * ((indent + 1) * 4)

            # placeholder replaced later by the helper definitions
            to_replace = "# __HEADER__{}".format(id(node))
            to_replaces.append(to_replace)
            rows.append("{}{}".format(pad, to_replace))

            for op, args in fct_info['code']:
                if op == "Assign":
                    rows.append("{}{} = (".format(pad, args['name']))
                    rows.extend(write_expression(
                        stack_fct_used, args["args"], indent + 1))
                    rows.append("{})".format(pad))
                elif op == "Return":
                    code = args["code"]
                    if output_names is None:
                        rows.append("{}return (".format(pad))
                        rows.extend(write_expression(
                            stack_fct_used, code, indent + 1))
                        rows.append("{})".format(pad))
                    else:
                        rows.append("{}return OnnxIdentity(".format(pad))
                        subrows = write_expression(
                            stack_fct_used, code, indent + 1)
                        subrows[-1] += ","
                        rows.extend(subrows)
                        rows.append("{}output_names={},".format(
                            sub_pad, str(output_names)))
                        rows.append("{}op_version=op_version".format(sub_pad))
                        rows.append("{})".format(pad))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to process operator '{}' at {}. "
                        "Make sure it is either an affectation, "
                        "either a return.".format(op, self.make_msg(args)))
            return rows

        stack_fct_used = []
        to_replaces = []
        rows = write_function(stack_fct_used, to_replaces, self._code_fct)

        # handling dtype parameter
        if len(to_replaces) != 1:
            raise RuntimeError(  # pragma: no cover
                "The following code misses a placeholder:\n{}".format(
                    "\n".join(rows)))
        index = -1
        for i, row in enumerate(rows):
            if to_replaces[0] in row:
                index = i
                break

        # Every onnx_* helper gets a local wrapper injecting dtype and
        # op_version; the 4-space indent matches the function body.
        header = [
            "    _{0} = lambda *args, op_version=op_version, **kwargs: {0}(*args, dtype=dtype, "
            "op_version=op_version, **kwargs)".format(fct)
            for fct in stack_fct_used]
        if header:
            header.append('')
        rows[index:index + 1] = header

        return "\n".join(rows)

    def visit(self, node, info):
        """
        Visits a node (called when the visitor enters it).

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        if 'type' not in info:
            return

        kind = info['type']
        if kind == "Module":
            return
        if kind == "FunctionDef":
            if self._is_stacked('FunctionDef'):
                raise RuntimeError("Nested functions are not allowed at {}.".format(
                    self.make_msg(node)))
            self._stack.append(
                ('FunctionDef', {'args': [], 'code': [], 'name': info['name'], 'default': [],
                                 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == "arguments":
            self._get_last('FunctionDef')
            return
        if kind == "arg":
            return
        if kind == "Assign":
            self._stack.append(
                ('Assign', {'args': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind in ('Name', 'Cst'):
            self._get_last(
                ('Assign', 'BinOp', 'Call', 'Return', 'FunctionDef', 'keyword', 'UnaryOp'))
            return
        if kind == 'BinOp':
            self._stack.append(
                ('BinOp', {'args': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == 'UnaryOp':
            self._stack.append(
                ('UnaryOp', {'args': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind in OnnxTranslator._binary_operators:
            _, buf = self._get_last(('BinOp', 'UnaryOp'))
            buf['op'] = kind
            return
        if kind == 'Call':
            self._stack.append(
                ('Call', {'name': info['str'], 'args': [], 'lineno': node.lineno,
                          'col_offset': node.col_offset}))
            return
        if kind == 'Return':
            self._get_last('FunctionDef')
            self._stack.append(
                ('Return', {'code': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == "Attribute":
            if info.get('str', '') == 'T':
                raise NotImplementedError(  # pragma: no cover
                    "Transpose should be done with numpy.transpose not with .T'{}' "
                    "at {}\n{}\n---\n{}".format(
                        info.get('type', '?'), self.make_msg(node),
                        pprint.pformat(info), pprint.pformat(self._stack)))
            self._get_last('Call')
            return
        if kind == 'keyword':
            self._get_last('Call')
            self._stack.append(
                ('keyword', {'name': "{0}".format(node.arg),
                             'lineno': getattr(node, 'lineno', '?'),
                             'col_offset': getattr(node, 'col_offset', '?')}))
            return
        if kind == 'List':
            self._get_last('keyword')
            self._stack.append(
                ('List', {'elts': [], 'lineno': getattr(node, 'lineno', '?'),
                          'col_offset': getattr(node, 'col_offset', '?')}))
            return
        if kind == 'Num':
            self._get_last(('List', 'UnaryOp', 'BinOp', 'FunctionDef', 'Call'))
            return
        if kind == 'Str':
            self._get_last('keyword')
            return

        raise NotImplementedError(  # pragma: no cover
            "Unable to interpret kind '{}' at {}\n{}\n---\n{}".format(
                info.get('type', '?'), self.make_msg(
                    node), pprint.pformat(info),
                pprint.pformat(self._stack)))

    def _fix_default_values(self, code_fct):
        """
        Maps default values with parameter names. Defaults are aligned
        with the last parameters, replaced by @see cl Parameter instances.

        @param      code_fct    stack entry ``('FunctionDef', data)``,
                                ``data['args']`` is modified inplace
        """
        fct_info = code_fct[1]
        defaults = fct_info['default']
        nbdef = len(defaults)
        nbpar = len(fct_info['args'])
        args = []
        for i, (name, annotation) in enumerate(fct_info['args']):
            j = nbdef - (nbpar - i)
            if j >= 0:
                p = OnnxTranslator.Parameter(
                    name, annotation=annotation, value=defaults[j])
            else:
                p = OnnxTranslator.Parameter(name, annotation=annotation)
            args.append(p)
        fct_info['args'] = args

    def _post_process(self, op, node):
        """
        Simplifies some operator such as ``OnnxNeg(2)``:
        a unary ``Sub`` applied to a numerical constant
        is replaced by the negated constant.

        @param      op          None to process the arguments of *node*,
                                any other value does nothing
        @param      node        stack data whose ``'args'`` may be modified inplace
        """
        if op is None and 'args' in node:
            for i, arg in enumerate(node['args']):
                if not isinstance(arg, tuple):
                    continue
                o, v = arg
                if (o == 'UnaryOp' and len(v['args']) == 1 and
                        isinstance(v['args'][0], (int, float, numpy.int64,
                                                  numpy.float32, numpy.float64))):
                    if v['op'] == 'Sub':
                        node['args'][i] = -v['args'][0]

    def depart(self, node, info):
        """
        Leaves a node (called when the visitor leaves it).

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        if 'type' not in info:
            return

        kind = info['type']
        if kind == "arg":
            return
        if kind == "arguments":
            _, buf = self._get_last('FunctionDef')
            for child in info['children']:
                if child['type'] == 'Str':
                    buf['default'].append(child['str'])
                elif child['type'] in ('Num', 'Cst'):
                    buf['default'].append(child['n'])
                elif child['type'] == 'arg':
                    buf['args'].append(
                        (child['str'], child.get('annotation', None)))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to interpret type '{}' in function definition."
                        "\n{}".format(
                            child['type'], pprint.pformat(info)))
            return

        if kind == "Name":
            op, buf = self._get_last(
                ('Assign', 'BinOp', 'Call', 'Return', 'FunctionDef', 'keyword',
                 'UnaryOp'),
                info)
            if op == 'Assign':
                buf['name'] = info['str']
            elif op in ('BinOp', 'Call', 'UnaryOp'):
                buf['args'].append(info['str'])
            elif op == 'Return':
                buf['code'] = info['str']
            elif op == 'keyword':
                buf['value'] = info['str']
            else:
                # op == 'FunctionDef': a default value which is a variable
                raise RuntimeError("Default value must be constant, variable '{}' was "
                                   "detected.".format(info['str']))
            return

        if kind in OnnxTranslator._binary_operators:
            self._get_last(('BinOp', 'UnaryOp'))
            return
        if kind in ('Call', 'BinOp', 'Assign', 'Return', 'UnaryOp'):
            op, buf = self._get_last(
                ('Call', 'BinOp', 'Assign', 'Return', 'UnaryOp'))
            self._post_process(op, buf)
            self._stack.pop()
            opp, parent = self._get_last(
                ('Call', 'BinOp', 'Assign', 'FunctionDef', 'Return', 'UnaryOp'))
            if opp in ('FunctionDef', 'Return'):
                parent['code'].append((op, buf))
            else:
                parent['args'].append((op, buf))
            self._post_process(None, parent)
            return
        if kind == 'FunctionDef':
            if len(self._stack) == 1:
                self._code_fct = self._stack[-1]
                self._fix_default_values(self._code_fct)
                self._stack = []
            return
        if kind == 'Module':
            return
        if kind == 'Attribute':
            op, buf = self._get_last(('Call', 'BinOp'))

            if len(info["children"]) > 0:
                fir = info["children"][0]
                if fir["type"] == "Name":
                    parent = fir["node"].id
                    info["str"] = "{0}.{1}".format(parent, info["str"])
                    info["children"][0]["remove"] = True

            # the first argument holds the called function name
            buf['name'] = info["str"]
            buf['args'][0] = info["str"]
            return
        if kind in ('Num', 'Cst'):
            op, buf = self._get_last(
                ('List', 'BinOp', 'UnaryOp', 'FunctionDef', 'Call'))
            if op == 'FunctionDef':
                return
            if op == 'List':
                buf['elts'].append(info['n'])
            else:
                buf['args'].append(info['n'])
            return
        if kind == 'Str':
            _, buf = self._get_last('keyword')
            buf['value'] = info['str']
            return
        if kind == 'List':
            op, buf = self._get_last('List')
            value = buf['elts']
            self._post_process(op, buf)
            self._stack.pop()
            _, parent = self._get_last('keyword')
            parent['value'] = value
            self._post_process(None, parent)
            return
        if kind == 'keyword':
            op, buf = self._get_last('keyword')
            name = buf["name"]
            if 'value' not in buf:
                raise RuntimeError(str(buf))  # pragma: no cover
            value = buf['value']
            self._post_process(op, buf)
            self._stack.pop()
            _, parent = self._get_last('Call')
            parent['args'].append(OnnxTranslator.Parameter(name, value))
            self._post_process(None, parent)
            return

        raise NotImplementedError(  # pragma: no cover
            "Unable to interpret kind '{}' at {}\n{}\n---\n{}".format(
                info.get('type', '?'), self.make_msg(
                    node), pprint.pformat(info),
                pprint.pformat(self._stack)))