Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2""" 

3@file 

4@brief Shortcut to *ops_cpu*. 

5""" 

6import pprint 

7import numpy 

8import onnx 

9import onnx.defs 

10from ..shape_object import ShapeObject 

11from ..type_object import SequenceType 

12from ._new_ops import OperatorSchema 

13 

14 

def _build_schemas():
    """
    Builds the lookup table of every :epkg:`onnx` operator schema.

    Each schema is registered twice: under ``<name>_<since_version>``
    and under ``<name>``, the latter always pointing to the schema
    with the highest ``since_version`` seen so far.

    @return     dictionary ``{ key: schema }``
    """
    table = {}
    for sch in onnx.defs.get_all_schemas_with_history():
        op_name = sch.name
        # Several versions can coexist, the unversioned key keeps
        # the most recent one.
        newer = (op_name not in table or
                 sch.since_version > table[op_name].since_version)
        if newer:
            table[op_name] = sch
        versioned = op_name + '_' + str(sch.since_version)
        table[versioned] = sch
    return table


# Every known schema, indexed by name and by name_version.
_schemas = _build_schemas()
# Operators for which only one of the expected attributes has to be given.
_at_least_one = {'Constant'}

31 

32 

class RuntimeTypeError(RuntimeError):
    """
    Exception raised by the runtime when a variable
    does not have the expected type.
    """

38 

39 

class DefaultNone:
    """
    Marker used as the default value of a parameter which
    was not specified but for which the operator defines
    a default behaviour.
    """

46 

47 

class OpRun:
    """
    Ancestor to all operators in this subfolder.
    The runtime for every node can be checked into
    `ONNX unit tests
    <https://github.com/onnx/onnx/tree/master/onnx/backend/test/case/node>`_.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        """
        @param      onnx_node               :epkg:`onnx` node
        @param      desc                    internal representation
        @param      expected_attributes     expected attributes for this node
        @param      options                 runtime options
        """
        self._provider = 'python'
        self.onnx_node = onnx_node
        self.desc = desc
        self.inplaces = {}

        # A class name holding '_' is looked up as is in the schemas
        # (presumably '<op>_<version>'), otherwise the node type is used.
        if '_' in self.__class__.__name__:
            self._schema = _schemas.get(self.__class__.__name__, None)
            if self._schema is None:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find class name '{}' in available schemas:"
                    "(onnx.__version__='{}')\n{}".format(
                        self.__class__.__name__,
                        onnx.__version__,
                        "\n".join(sorted(_schemas))))
        elif onnx_node.op_type in _schemas:
            self._schema = _schemas[onnx_node.op_type]
        else:
            self._schema = self._find_custom_operator_schema(onnx_node.op_type)

        # Attributes stored in the description override the options.
        if desc is not None:
            if 'atts' in desc:
                for a, b in desc['atts'].items():
                    if not isinstance(b, dict) or 'value' not in b:
                        raise ValueError(  # pragma: no cover
                            "Unexpected value {}.".format(b))
                    options[a] = (b['value_rt'] if 'value_rt' in b
                                  else b['value'])
        if expected_attributes is not None:
            if onnx_node.op_type in _at_least_one:
                # Only one of the expected attributes is required.
                done = 0
                for a, b in expected_attributes.items():
                    if a in options:
                        setattr(self, a, b)
                        done += 1
                if done == 0:
                    # The message lists every expected attribute, it does
                    # not rely on the loop variable which is undefined
                    # when *expected_attributes* is empty.
                    raise RuntimeError(  # pragma: no cover
                        "All parameters '{}' are missing from operator '{}', "
                        "given {}.".format(
                            ", ".join(sorted(expected_attributes)),
                            onnx_node.op_type, list(sorted(options))))
            else:
                for a, b in expected_attributes.items():
                    if a not in options:
                        if b is DefaultNone:
                            setattr(self, a, None)
                        elif b is None:
                            raise RuntimeError(  # pragma: no cover
                                "Parameter '{}' is missing from operator '{}', "
                                "given {}.".format(
                                    a, onnx_node.op_type, list(sorted(options))))
                        else:
                            setattr(self, a, b)
        # Given values always win over the default ones.
        for k, v in options.items():
            setattr(self, k, v)

        if onnx_node.op_type not in _at_least_one:
            for k, v in self._schema.attributes.items():
                if not hasattr(self, k) and getattr(v, 'required', True):
                    raise RuntimeError(  # pragma: no cover
                        "Attribute '{}' is expected based on ONNX specifications "
                        "for node '{}' and options {}.".format(
                            k, onnx_node.op_type, pprint.pformat(options)))

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return False

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of an operator unknown to :epkg:`onnx`.
        Must be overwritten by classes implementing custom operators.

        @param      op_name     operator name
        """
        raise NotImplementedError(  # pragma: no cover
            "This method should be overwritten for operator "
            "'{}'.".format(op_name))

    def __str__(self):
        """
        Returns the class name, the operator type and
        every public attribute, one per line.
        """
        atts = [self.__class__.__name__ + '(',
                " op_type={}".format(self.onnx_node.op_type)]
        for k, v in sorted(self.__dict__.items()):
            if k in {'desc', 'onnx_node'}:
                continue
            # Only names starting with a lower case letter and
            # not ending with '_' are displayed.
            if 'a' <= k[0] <= 'z' and k[-1] != '_':
                atts.append(' {0}={1},'.format(k, v))
        atts.append(')')
        return "\n".join(atts)

    def _run(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(  # pragma: no cover
            "This method should be overwritten.")

    def run(self, *args, **kwargs):  # pylint: disable=E0202
        """
        Calls method ``_run`` and adds the input types
        to the message of any raised *TypeError*.
        """
        try:
            res = self._run(*args, **kwargs)
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Issues with types {} (operator {}).".format(
                    ", ".join(str(type(_)) for _ in args),
                    self.__class__.__name__)) from e
        return res

    def switch_initializers_dtype(self, dtype_in=numpy.float32,
                                  dtype_out=numpy.float64):
        """
        Casts every attribute which is a :epkg:`numpy` array of type
        *dtype_in* into *dtype_out*. If the class defines method
        ``_run_no_checks_``, it then replaces method ``run``.

        @param      dtype_in    previous type
        @param      dtype_out   next type
        @return                 done operations
        """
        done = []
        for k, v in sorted(self.__dict__.items()):
            if k in {'desc', 'onnx_node'}:
                continue
            if isinstance(v, numpy.ndarray):
                if v.dtype == dtype_in:
                    v = v.astype(dtype_out)
                    setattr(self, k, v)
                    done.append(("+", "att", k, getattr(self, k)))
                else:
                    done.append(("-", "att", k, getattr(self, k)))
        if hasattr(self, '_run_no_checks_') and hasattr(self, 'run'):
            self.run = self._run_no_checks_  # pylint: disable=E0202,E1101
        return done

    def infer_shapes(self, *args, **kwargs):
        """
        Infer shapes of the outputs given the shapes
        of the inputs. It works the same way as method *run*.
        The result must be a tuple of *ShapeObject*.
        """
        try:
            res = self._infer_shapes(*args, **kwargs)
        except TypeError as e:
            raise TypeError(
                "Issues with (operator '{}') and shapes\n{}"
                "\n----args\n{}\n------kwargs\n{}".format(
                    self.__class__.__name__,
                    "\n".join(str(_) for _ in args),
                    pprint.pformat(args),
                    pprint.pformat(kwargs))) from e
        if not isinstance(res, tuple):
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        for a in res:
            if not isinstance(a, ShapeObject):
                raise TypeError(  # pragma: no cover
                    "One shape is not a ShapeObject but {} (operator '{}')".format(
                        type(a), self.__class__.__name__))
        return res

    def _infer_shapes(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover

    def infer_types(self, *args, **kwargs):
        """
        Infer types of the outputs given the types
        of the inputs. It works the same way as method *run*.
        The result must be a tuple of numpy types or sequence types.
        """
        try:
            res = self._infer_types(*args, **kwargs)
        except TypeError as e:
            raise TypeError(
                "Issues with (operator '{}') and types\n{}"
                "\n----args\n{}\n------kwargs\n{}".format(
                    self.__class__.__name__,
                    "\n".join(str(_) for _ in args),
                    pprint.pformat(args),
                    pprint.pformat(kwargs))) from e
        if not isinstance(res, tuple):
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        for a in res:
            if not isinstance(a, (numpy.dtype, SequenceType)) and a not in {
                    numpy.int8, numpy.uint8, numpy.float16, numpy.float32,
                    numpy.float64, numpy.int32, numpy.int64, numpy.int16,
                    numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_,
                    numpy.uint64, bool, str}:
                raise TypeError(  # pragma: no cover
                    "Type ({}, {}) is not a numpy type or a sequence type "
                    "(operator '{}')".format(
                        a, type(a), self.__class__.__name__))
        return res

    def _infer_types(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover

    def infer_sizes(self, *args, **kwargs):
        """
        Infer sizes required for computation.
        It works the same way as method *run*.
        The result must be a tuple.
        """
        try:
            res = self._infer_sizes(*args, **kwargs)
        except TypeError as e:
            raise TypeError(
                "Issues with (operator '{}') and sizes\n{}"
                "\n----args\n{}\n------kwargs\n{}".format(
                    self.__class__.__name__,
                    "\n".join(str(_) for _ in args),
                    pprint.pformat(args),
                    pprint.pformat(kwargs))) from e
        if not isinstance(res, tuple):
            # The message now agrees with the check: a tuple is expected.
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        return res

    def _infer_sizes(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover

    def enable_inplace_compute(self, index):
        """
        Tells the node that one input can be overwritten.

        @param      index       input index
        """
        self.inplaces[index] = True

    @property
    def args_default(self):
        """
        Returns the list of arguments as well as
        the list of parameters with the default values
        (close to the signature). Empty containers are
        displayed as ``None``.
        """
        inps = []
        if hasattr(self, 'atts'):
            for k, v in self.atts.items():  # pylint: disable=E1101
                if isinstance(v, (list, tuple, dict)) and len(v) == 0:
                    v = None
                inps.append('%s=%r' % (k, v))
        return inps

    @property
    def args_default_modified(self):
        """
        Returns the list of modified parameters,
        None if the class does not define *atts*.
        """
        if not hasattr(self, 'atts'):
            return None

        inps = []
        for k, v in self.atts.items():  # pylint: disable=E1101
            val = getattr(self, k, None)
            # An array is compared to a default list as a list.
            if isinstance(val, numpy.ndarray) and isinstance(v, list):
                val = list(val)
            try:
                if val != v:
                    inps.append('%s=%r' % (k, val))
            except ValueError as e:
                raise ValueError(  # pragma: no cover
                    "Unexpected value for v=%r and val=%r." % (v, val)) from e
        return inps

    @property
    def args_optional(self):
        """
        Returns the list of optional arguments.
        """
        inps = []
        if hasattr(self, 'optional_inputs'):
            for k, v in self.optional_inputs.items():  # pylint: disable=E1101
                inps.append('%s=%r' % (k, v))
        return inps

    @property
    def args_mandatory(self):
        """
        Returns the list of mandatory arguments,
        None if the class does not define *mandatory_inputs*.
        """
        if hasattr(self, 'mandatory_inputs'):
            return self.mandatory_inputs  # pylint: disable=E1101
        return None

    def to_python(self, inputs):
        """
        Returns a python code equivalent to this operator.

        @param      inputs      inputs name
        @return                 imports, python code, both as strings
        """
        raise NotImplementedError(
            "Operator '{}' has no equivalent python code.".format(self.__class__.__name__))  # pragma: no cover

    def _to_python_numpy(self, inputs, numpy_name):
        """
        Returns the python code calling function
        ``numpy.<numpy_name>`` on the inputs.

        @param      inputs      inputs name
        @param      numpy_name  name of the numpy function
        @return                 imports, python code, both as strings
        """
        return ("import numpy",
                "return numpy.%s(%s)" % (numpy_name, ", ".join(inputs)))

    @property
    def atts_value(self):
        "Returns all parameters in a dictionary."
        if hasattr(self, 'atts'):
            return {k: getattr(self, k)
                    for k in self.atts}  # pylint: disable=E1101
        return None

386 

387 

class OpRunUnary(OpRun):
    """
    Ancestor to all unary operators in this subfolder.
    Methods *run*, *infer_shapes*, *infer_types* take
    a single input *x*.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def run(self, x):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run`` and adds the input type
        to the message of any raised *TypeError*.
        """
        try:
            res = self._run(x)
        except TypeError as e:
            # This class handles unary operators, the message says so.
            raise TypeError(  # pragma: no cover
                "Issues with types {} (unary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x]),
                    self.__class__.__name__)) from e
        return res

    def infer_shapes(self, x):  # pylint: disable=E0202,W0221
        """
        Calls method ``_infer_shapes``.
        """
        try:
            return self._infer_shapes(x)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (operator {}).".format(
                    x.dtype, self.__class__.__name__)) from e

    def _infer_shapes(self, x):  # pylint: disable=E0202,W0221
        """
        Returns the same shape by default.
        """
        return (x, )

    def infer_types(self, x):  # pylint: disable=E0202,W0221
        """
        Calls method ``_infer_types``.
        """
        try:
            return self._infer_types(x)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (operator {}).".format(
                    x, self.__class__.__name__)) from e

    def _infer_types(self, x):  # pylint: disable=E0202,W0221
        """
        Returns the same type by default.
        """
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and returns its outputs
        preceded by ``dict(temp=0)``.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res

444 

445 

class OpRunArg(OpRunUnary):
    """
    Base class of the unary operators returning the position
    of extremas (ArgMax, ...). The output type must be
    ``numpy.int64``. The class must have attributes
    *axis* and *keepdims*.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=expected_attributes,
                            **options)
        # Both attributes are mandatory, keepdims is checked first.
        for required in ('keepdims', 'axis'):
            if hasattr(self, required):
                continue
            raise AttributeError(  # pragma: no cover
                "Attribute '%s' is missing." % required)

    def run(self, x):  # pylint: disable=E0202
        """
        Calls method ``_run`` and checks the first output
        has type ``numpy.int64``.
        """
        outputs = OpRunUnary.run(self, x)
        first = outputs[0]
        if first.dtype != numpy.int64:
            cls_name = self.__class__.__name__
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: should be '{}' != output '{}' "
                "(operator '{}')".format(
                    numpy.int64, first.dtype, cls_name))
        return outputs

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Reduces the input shape along *axis*.
        """
        reduced = x.reduce(
            self.axis, self.keepdims,  # pylint: disable=E1101
            dtype=numpy.int64)
        return (reduced, )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        The output type is always ``numpy.int64``.
        """
        return (numpy.int64, )

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        """
        Same as *run* without checking the output type.
        """
        return OpRunUnary.run(self, x)

488 

489 

class OpRunUnaryNum(OpRunUnary):
    """
    Base class of the unary numerical operators of
    this subfolder. The first output must have the
    same type as the input.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=expected_attributes,
                            **options)

    def run(self, x):  # pylint: disable=E0202
        """
        Calls method ``_run`` and checks the first output,
        unless it is a list, has the same type as the input.
        """
        outputs = OpRunUnary.run(self, x)
        first = outputs[0]
        if isinstance(first, list):
            # Lists are not checked.
            return outputs
        if first.dtype != x.dtype:
            cls_name = self.__class__.__name__
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: input '{}' != output '{}' "
                "(operator '{}')".format(
                    x.dtype, first.dtype, cls_name))
        return outputs

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        """
        Same as *run* without checking the output type.
        """
        return OpRunUnary.run(self, x)

517 

518 

class OpRunClassifierProb(OpRunUnary):
    """
    Ancestor to all classifiers in this subfolder returning
    labels and probabilities. Checks that the probabilities
    have the same type as the input when the input is a float.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=expected_attributes,
                            **options)

    def run(self, x):  # pylint: disable=E0202
        """
        Calls method ``_run`` and checks the type of
        the second output (probabilities).
        """
        res = OpRunUnary.run(self, x)
        if x.dtype in (numpy.float32, numpy.float64) and res[1].dtype != x.dtype:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: {} != {} (operator '{}')".format(
                    x.dtype, res[1].dtype, self.__class__.__name__))
        return res

    @property
    def nb_classes(self):
        """
        Returns the number of expected classes: the length of
        the longest list among *classlabels_ints*,
        *classlabels_int64s*, *classlabels_strings*.
        A missing or None attribute counts for zero.
        """
        sizes = []
        for name in ('classlabels_ints', 'classlabels_int64s',
                     'classlabels_strings'):
            # getattr with a default: an AttributeError raised inside
            # a property would be reported as 'nb_classes' missing.
            labels = getattr(self, name, None)
            sizes.append(0 if labels is None else len(labels))
        return max(sizes)

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        """
        Same as *run* without checking the output type.
        """
        return OpRunUnary.run(self, x)

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns the shapes of the labels and of the probabilities,
        both built from the first dimension of *x*.
        """
        return (ShapeObject((x[0], ), dtype=numpy.int64,
                            name="{}-0".format(self.__class__.__name__)),
                ShapeObject((x[0], self.nb_classes), dtype=x.dtype,
                            name="{}-1".format(self.__class__.__name__)))

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the type of the labels and the probabilities.
        """
        return (numpy.int64, x.dtype)

568 

569 

class OpRunBinary(OpRun):
    """
    Ancestor to all binary operators in this subfolder.
    Checks that inputs type are the same.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def run(self, x, y):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run`` after checking both inputs
        are not None and share the same type.
        """
        if x is None or y is None:
            # The failure is a missing input, not a type mismatch.
            raise RuntimeError(
                "Inputs x and y must not be None, "
                "got types {} and {} ({})".format(
                    type(x), type(y), type(self)))
        if x.dtype != y.dtype:
            raise RuntimeTypeError(
                "Input type mismatch: {} != {} (operator '{}', shapes {}, {})".format(
                    x.dtype, y.dtype, self.__class__.__name__,
                    x.shape, y.shape))
        try:
            res = self._run(x, y)
        except (TypeError, ValueError) as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (binary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x, y]),
                    self.__class__.__name__)) from e
        return res

    def _run_no_checks_(self, x, y):  # pylint: disable=W0221
        """
        Calls method ``_run`` without checking the inputs.
        """
        try:
            res = self._run(x, y)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (binary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x, y]),
                    self.__class__.__name__)) from e
        return res

    def _infer_shapes(self, x, y):  # pylint: disable=W0221
        """
        Returns the same shape by default.
        We assume the operator returns the biggest
        shapes as the operator could be using broadcasting.
        """
        try:
            res = x.broadcast(y)
            add = "broadcast"
        except RuntimeError:  # pragma: no cover
            # We know x and y and the same number of dimensions.
            # We pick the first one even if it might be wrong.
            res = x
            add = "1"
        if res.name is None:
            return (res.copy(name="{}{}".format(
                self.__class__.__name__, add)), )
        return (res.copy(name="{}-{}{}".format(
            res.name, self.__class__.__name__, add)), )

    def _infer_types(self, x, y):  # pylint: disable=W0221
        """
        Returns the type of the first input by default.
        """
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and returns its outputs
        preceded by ``dict(temp=0)``.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res

645 

646 

class OpRunBinaryComparison(OpRunBinary):
    """
    Base class of the binary operators of this subfolder
    which compare two tensors.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunBinary.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)

    def _infer_types(self, x, y):  # pylint: disable=W0221
        """
        The output type is always ``numpy.bool_``.
        """
        out_type = numpy.bool_
        return (out_type, )

661 

662 

class OpRunBinaryNum(OpRunBinary):
    """
    Base class of the binary numerical operators of this
    subfolder. The first output must have the same type
    as the first input.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunBinary.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)

    def run(self, x, y):  # pylint: disable=E0202
        """
        Calls method ``run`` of the parent class then
        checks the type of the first output.
        """
        outputs = OpRunBinary.run(self, x, y)
        first = outputs[0]
        if first.dtype != x.dtype:
            cls_name = self.__class__.__name__
            raise RuntimeTypeError(
                "Output type mismatch: {} != {} (operator '{}')".format(
                    x.dtype, first.dtype, cls_name))
        return outputs

    def _run_no_checks_(self, x, y):  # pylint: disable=W0221
        """
        Delegates to ``_run_no_checks_`` of the parent class.
        """
        return OpRunBinary._run_no_checks_(self, x, y)

691 

692 

class OpRunBinaryNumpy(OpRunBinaryNum):
    """
    Binary operator relying on a :epkg:`numpy` function and
    handling inplace computation.
    *numpy_fct* is a binary numpy function which
    takes two matrices and accepts an argument *out*
    for inplace operations.
    """

    def __init__(self, numpy_fct, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRunBinaryNum.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)
        self.numpy_fct = numpy_fct
        # Divisions never run inplace on integer inputs.
        divisions = (numpy.divide, numpy.true_divide)
        self._cannot_inplace_int = self.numpy_fct in divisions

    def _run(self, a, b):  # pylint: disable=W0221
        """
        Applies *numpy_fct*, writing the result into one of the
        inputs when this input was declared as overwritable and
        is at least as big as the other one. Falls back to a
        regular call if the inplace call fails.
        """
        fct = self.numpy_fct
        if (self._cannot_inplace_int and
                numpy.issubdtype(a.dtype, numpy.integer)):
            return (fct(a, b), )

        if self.inplaces.get(0, False) and a.size >= b.size:
            # First input receives the result.
            if len(a.shape) == 1 and b.shape == (1, 1):
                a = a.reshape(1, a.shape[0])
            try:
                fct(a, b, out=a)
            except (ValueError, TypeError):
                return (fct(a, b), )
            return (a, )

        if self.inplaces.get(1, False) and a.size <= b.size:
            # Second input receives the result.
            if len(b.shape) == 1 and a.shape == (1, 1):
                b = b.reshape(b.shape[0], 1)
            try:
                fct(a, b, out=b)
            except (ValueError, TypeError):
                return (fct(a, b), )
            return (b, )

        return (fct(a, b), )

    def to_python(self, inputs):
        """
        Returns a python code equivalent to this operator.

        @param      inputs      inputs name
        @return                 imports, python code, both as strings
        """
        first_inplace = self.inplaces.get(0, False)
        second_inplace = self.inplaces.get(1, False)
        comment = "# inplaces not take into account {}-{}".format(
            first_inplace, second_inplace)
        call = "return numpy.{0}({1})".format(
            self.numpy_fct.__name__, ', '.join(inputs))
        return "import numpy", "\n".join([comment, call])

746 

747 

class OpRunReduceNumpy(OpRunUnaryNum):
    """
    Implements the reduce logic.
    It must have a parameter *axes*.
    After the initialization, attribute *axes* is
    None if it was empty, a tuple if it was a list or an array.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        """
        @param      onnx_node               :epkg:`onnx` node
        @param      desc                    internal representation
        @param      expected_attributes     expected attributes for this node,
                                            it must declare *axes* or
                                            *noop_with_empty_axes*
        @param      options                 runtime options
        """
        # None is the declared default value: it is handled as an empty
        # dictionary to raise the explicit error below and not a TypeError.
        declared = {} if expected_attributes is None else expected_attributes
        if ('noop_with_empty_axes' not in declared and
                'axes' not in declared):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'axes' is expected but not found in {} "
                "from class {}".format(expected_attributes, type(self)))
        if declared.get('noop_with_empty_axes', 0):
            # 'axes' may be absent when only 'noop_with_empty_axes'
            # is declared, hence get and not [].
            axes = declared.get('axes', None)
            if axes is None or len(axes) == 0:
                raise RuntimeError(  # pragma: no cover
                    "Parameter 'axes' cannot be empty as {} (noop_with_empty_axes=1) "
                    "from class {}".format(expected_attributes, type(self)))
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=expected_attributes,
                               **options)
        # Normalizes axes: None if empty, a tuple otherwise.
        if isinstance(self.axes, numpy.ndarray):  # pylint: disable=E0203
            if (len(self.axes.shape) == 0 or  # pylint: disable=E0203
                    self.axes.shape[0] == 0):  # pylint: disable=E0203
                self.axes = None
            else:
                self.axes = tuple(self.axes)
        elif self.axes in [[], tuple()]:  # pylint: disable=E0203
            self.axes = None
        elif isinstance(self.axes, list):  # pylint: disable=E0203
            self.axes = tuple(self.axes)

780 

781 

class OpRunCustom(OpRun):
    """
    Base class for custom operators defined outside
    *mlprodict*, it automates some of their methods.
    """

    class OpRunCustomSchema(OperatorSchema):
        """
        Schema of a custom operator, its attributes
        are taken from ``cls.atts``.
        """

        def __init__(self, cls):
            OperatorSchema.__init__(self, cls.__name__)
            self.attributes = cls.atts

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Builds the schema of the custom operator *op_name* if it is
        the class name or the class attribute *op_name*,
        raises an exception otherwise.
        """
        cls = self.__class__
        found = op_name == cls.__name__
        if not found and hasattr(cls, 'op_name'):
            found = cls.op_name == op_name  # pylint: disable=E1101
        if not found:
            raise RuntimeError(  # pragma: no cover
                "Unable to find a schema for operator '{}'.".format(op_name))
        return OpRunCustom.OpRunCustomSchema(cls)