Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *ops_cpu*.
5"""
6import pprint
7import numpy
8import onnx
9import onnx.defs
10from ..shape_object import ShapeObject
11from ..type_object import SequenceType
12from ._new_ops import OperatorSchema
def _build_schemas():
    """
    Builds the table of operator schemas exposed by :epkg:`onnx`.

    Every schema returned by
    ``onnx.defs.get_all_schemas_with_history()`` is registered twice:
    under ``<name>_<since_version>`` and under ``<name>``. The plain
    name always points to the schema with the highest *since_version*.

    @return     dictionary ``{ key: schema }``
    """
    table = {}
    for sch in onnx.defs.get_all_schemas_with_history():
        current = table.get(sch.name, None)
        # Several versions can coexist, the plain name keeps the most
        # recent one.
        if current is None or sch.since_version > current.since_version:
            table[sch.name] = sch
        table['{}_{}'.format(sch.name, sch.since_version)] = sch
    return table
# Table of operator schemas, built once when the module is imported.
_schemas = _build_schemas()
# Operators for which *OpRun.__init__* only requires that at least one
# of the expected attributes is given, and for which the check against
# the schema attributes is skipped.
_at_least_one = {'Constant'}
class RuntimeTypeError(RuntimeError):
    """
    Exception raised when a variable does not have
    the expected type.
    """
class DefaultNone:
    """
    Marker used as the default value of a parameter which is
    not set but for which the operator has a default behaviour.
    """
class OpRun:
    """
    Ancestor to all operators in this subfolder.
    The runtime for every node can be checked against
    `ONNX unit tests
    <https://github.com/onnx/onnx/tree/master/onnx/backend/test/case/node>`_.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        """
        @param      onnx_node               :epkg:`onnx` node
        @param      desc                    internal representation
        @param      expected_attributes     expected attributes for this node
        @param      options                 runtime options
        """
        self._provider = 'python'
        self.onnx_node = onnx_node
        self.desc = desc
        self.inplaces = {}

        cls_name = self.__class__.__name__
        if '_' in cls_name:
            # A class name holding '_' is looked up as is
            # in the table of schemas.
            self._schema = _schemas.get(cls_name, None)
            if self._schema is None:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find class name '{}' in available schemas:"
                    "(onnx.__version__='{}')\n{}".format(
                        cls_name,
                        onnx.__version__,
                        "\n".join(sorted(_schemas))))
        elif onnx_node.op_type in _schemas:
            self._schema = _schemas[onnx_node.op_type]
        else:
            self._schema = self._find_custom_operator_schema(
                onnx_node.op_type)

        # Attributes stored in the internal representation are added
        # to the options, 'value_rt' is preferred to 'value'.
        if desc is not None:
            if 'atts' in desc:
                for a, b in desc['atts'].items():
                    if not isinstance(b, dict) or 'value' not in b:
                        raise ValueError(  # pragma: no cover
                            "Unexpected value {}.".format(b))
                    options[a] = (b['value_rt'] if 'value_rt' in b
                                  else b['value'])

        if expected_attributes is not None:
            if onnx_node.op_type in _at_least_one:
                # At least one of the expected attributes must be given.
                done = 0
                for a, b in expected_attributes.items():
                    if a in options:
                        # Overwritten below by the value
                        # stored in options.
                        setattr(self, a, b)
                        done += 1
                if done == 0:
                    # The message lists every expected name, it does not
                    # rely on the loop variable which is undefined
                    # when expected_attributes is empty.
                    raise RuntimeError(  # pragma: no cover
                        "All parameters '{}' are missing from operator '{}', "
                        "given {}.".format(
                            ", ".join(sorted(expected_attributes)),
                            onnx_node.op_type, list(sorted(options))))
            else:
                for a, b in expected_attributes.items():
                    if a not in options:
                        if b is DefaultNone:
                            setattr(self, a, None)
                        elif b is None:
                            raise RuntimeError(  # pragma: no cover
                                "Parameter '{}' is missing from operator '{}', "
                                "given {}.".format(
                                    a, onnx_node.op_type,
                                    list(sorted(options))))
                        else:
                            setattr(self, a, b)
        for k, v in options.items():
            setattr(self, k, v)

        if onnx_node.op_type not in _at_least_one:
            # Every attribute of the schema considered as required
            # must be available at this stage.
            for k, v in self._schema.attributes.items():
                if not hasattr(self, k) and getattr(v, 'required', True):
                    raise RuntimeError(  # pragma: no cover
                        "Attribute '{}' is expected based on ONNX specifications "
                        "for node '{}' and options {}.".format(
                            k, onnx_node.op_type, pprint.pformat(options)))

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return False

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of an operator unknown to :epkg:`onnx`.
        Should be overwritten.

        @param      op_name     operator name
        """
        raise NotImplementedError(  # pragma: no cover
            "This method should be overwritten for operator "
            "'{}'.".format(op_name))

    def __str__(self):
        """
        usual
        """
        atts = [self.__class__.__name__ + '(',
                " op_type={}".format(self.onnx_node.op_type)]
        for k, v in sorted(self.__dict__.items()):
            if k in {'desc', 'onnx_node'}:
                continue
            # Only names starting with a lower case letter and
            # not ending with '_' are displayed.
            if 'a' <= k[0] <= 'z' and k[-1] != '_':
                atts.append(' {0}={1},'.format(k, v))
        atts.append(')')
        return "\n".join(atts)

    def _run(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(  # pragma: no cover
            "This method should be overwritten.")

    def run(self, *args, **kwargs):  # pylint: disable=E0202
        """
        Calls method ``_run``. A *TypeError* is raised again
        with the types of the inputs and the operator name.
        """
        try:
            res = self._run(*args, **kwargs)
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Issues with types {} (operator {}).".format(
                    ", ".join(str(type(_)) for _ in args),
                    self.__class__.__name__)) from e
        return res

    def switch_initializers_dtype(self, dtype_in=numpy.float32,
                                  dtype_out=numpy.float64):
        """
        Casts every attribute which is a :epkg:`numpy` array of type
        *dtype_in* into *dtype_out*. If the instance has a method
        ``_run_no_checks_``, it replaces method ``run``.

        @param      dtype_in    previous type
        @param      dtype_out   next type
        @return                 done operations, a list of tuples
                                ``(sign, 'att', name, value)``, sign is
                                ``'+'`` if the array was cast, ``'-'``
                                otherwise
        """
        done = []
        for k, v in sorted(self.__dict__.items()):
            if k in {'desc', 'onnx_node'}:
                continue
            if isinstance(v, numpy.ndarray):
                if v.dtype == dtype_in:
                    v = v.astype(dtype_out)
                    setattr(self, k, v)
                    done.append(("+", "att", k, getattr(self, k)))
                else:
                    done.append(("-", "att", k, getattr(self, k)))
        if hasattr(self, '_run_no_checks_') and hasattr(self, 'run'):
            self.run = self._run_no_checks_  # pylint: disable=E0202,E1101
        return done

    def _infer_error_message(self, what, args, kwargs):
        """
        Builds the message of the exception raised when one of the
        methods ``_infer_*`` fails with a *TypeError*.

        @param      what        kind of inference (shapes, types, sizes)
        @param      args        positional arguments received
        @param      kwargs      named arguments received
        @return                 string
        """
        return (
            "Issues with (operator '{}') and {}\n{}"
            "\n----args\n{}\n------kwargs\n{}".format(
                self.__class__.__name__, what,
                "\n".join(str(_) for _ in args),
                pprint.pformat(args),
                pprint.pformat(kwargs)))

    def infer_shapes(self, *args, **kwargs):
        """
        Infer shapes of the outputs given the shapes
        of the inputs. It works the same way as method *run*.
        The result must be a tuple of *ShapeObject*.
        """
        try:
            res = self._infer_shapes(*args, **kwargs)
        except TypeError as e:
            raise TypeError(
                self._infer_error_message("shapes", args, kwargs)) from e
        if not isinstance(res, tuple):
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        for a in res:
            if not isinstance(a, ShapeObject):
                raise TypeError(  # pragma: no cover
                    "One shape is not a ShapeObject but {} (operator '{}')".format(
                        type(a), self.__class__.__name__))
        return res

    def _infer_shapes(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover

    def infer_types(self, *args, **kwargs):
        """
        Infer types of the outputs given the types
        of the inputs. It works the same way as method *run*.
        The result must be a tuple of numpy types or sequence types.
        """
        try:
            res = self._infer_types(*args, **kwargs)
        except TypeError as e:
            raise TypeError(
                self._infer_error_message("types", args, kwargs)) from e
        if not isinstance(res, tuple):
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        for a in res:
            if not isinstance(a, (numpy.dtype, SequenceType)) and a not in {
                    numpy.int8, numpy.uint8, numpy.float16, numpy.float32,
                    numpy.float64, numpy.int32, numpy.int64, numpy.int16,
                    numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_,
                    numpy.uint64, bool, str}:
                raise TypeError(  # pragma: no cover
                    "Type ({}, {}) is not a numpy type or a sequence type "
                    "(operator '{}')".format(
                        a, type(a), self.__class__.__name__))
        return res

    def _infer_types(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover

    def infer_sizes(self, *args, **kwargs):
        """
        Infer sizes required for computation.
        It works the same way as method *run*.
        The result must be a tuple.
        """
        try:
            res = self._infer_sizes(*args, **kwargs)
        except TypeError as e:
            raise TypeError(
                self._infer_error_message("sizes", args, kwargs)) from e
        if not isinstance(res, tuple):
            # The check is about a tuple, so is the message.
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        return res

    def _infer_sizes(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover

    def enable_inplace_compute(self, index):
        """
        Tells the node that one input can be overwritten.

        @param      index       input index
        """
        self.inplaces[index] = True

    @property
    def args_default(self):
        """
        Returns the list of arguments as well as
        the list of parameters with the default values
        (close to the signature). Empty containers are
        displayed as ``None``. The list is empty if the class
        has no attribute *atts*.
        """
        inps = []
        if hasattr(self, 'atts'):
            for k, v in self.atts.items():  # pylint: disable=E1101
                if isinstance(v, (list, tuple, dict)) and len(v) == 0:
                    v = None
                inps.append('%s=%r' % (k, v))
        return inps

    @property
    def args_default_modified(self):
        """
        Returns the list of modified parameters, the parameters
        whose value differs from the default one stored in *atts*.
        Returns None if the class has no attribute *atts*.
        """
        if not hasattr(self, 'atts'):
            return None

        inps = []
        for k, v in self.atts.items():  # pylint: disable=E1101
            val = getattr(self, k, None)
            # An array is compared to a list as a list.
            if isinstance(val, numpy.ndarray) and isinstance(v, list):
                val = list(val)
            try:
                if val != v:
                    inps.append('%s=%r' % (k, val))
            except ValueError as e:
                raise ValueError(  # pragma: no cover
                    "Unexpected value for v=%r and val=%r." % (v, val)) from e
        return inps

    @property
    def args_optional(self):
        """
        Returns the list of optional arguments.
        The list is empty if the class has no attribute
        *optional_inputs*.
        """
        inps = []
        if hasattr(self, 'optional_inputs'):
            for k, v in self.optional_inputs.items():  # pylint: disable=E1101
                inps.append('%s=%r' % (k, v))
        return inps

    @property
    def args_mandatory(self):
        """
        Returns the list of mandatory arguments,
        None if the class has no attribute *mandatory_inputs*.
        """
        if hasattr(self, 'mandatory_inputs'):
            return self.mandatory_inputs  # pylint: disable=E1101
        return None

    def to_python(self, inputs):
        """
        Returns a python code equivalent to this operator.

        @param      inputs      inputs name
        @return                 imports, python code, both as strings
        """
        raise NotImplementedError(
            "Operator '{}' has no equivalent python code.".format(
                self.__class__.__name__))  # pragma: no cover

    def _to_python_numpy(self, inputs, numpy_name):
        """
        Returns the python code calling function *numpy_name*
        from :epkg:`numpy` on the inputs.

        @param      inputs      inputs name
        @param      numpy_name  name of the numpy function
        @return                 imports, python code, both as strings
        """
        return ("import numpy",
                "return numpy.%s(%s)" % (numpy_name, ", ".join(inputs)))

    @property
    def atts_value(self):
        "Returns all parameters in a dictionary."
        if hasattr(self, 'atts'):
            return {k: getattr(self, k)
                    for k in self.atts}  # pylint: disable=E1101
        return None
class OpRunUnary(OpRun):
    """
    Ancestor to all unary operators in this subfolder.
    Methods *run*, *infer_shapes*, *infer_types* take one input.
    By default, the output has the same shape and the same type
    as the input.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        """
        @param      onnx_node               :epkg:`onnx` node
        @param      desc                    internal representation
        @param      expected_attributes     expected attributes for this node
        @param      options                 runtime options
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def run(self, x):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run``. A *TypeError* is raised again
        with the type of the input and the operator name.
        """
        try:
            res = self._run(x)
        except TypeError as e:
            # This class deals with unary operators,
            # the message says so.
            raise TypeError(  # pragma: no cover
                "Issues with types {} (unary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x]),
                    self.__class__.__name__)) from e
        return res

    def infer_shapes(self, x):  # pylint: disable=E0202,W0221
        """
        Calls method ``_infer_shapes``.
        """
        try:
            return self._infer_shapes(x)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (operator {}).".format(
                    x.dtype, self.__class__.__name__)) from e

    def _infer_shapes(self, x):  # pylint: disable=E0202,W0221
        """
        Returns the same shape by default.
        """
        return (x, )

    def infer_types(self, x):  # pylint: disable=E0202,W0221
        """
        Calls method ``_infer_types``.
        """
        try:
            return self._infer_types(x)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (operator {}).".format(
                    x, self.__class__.__name__)) from e

    def _infer_types(self, x):  # pylint: disable=E0202,W0221
        """
        Returns the same type by default.
        """
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and returns ``dict(temp=0)``
        followed by the outputs.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
class OpRunArg(OpRunUnary):
    """
    Ancestor to all unary operators in this subfolder
    which produce the position of extremas (ArgMax, ...).
    The output type is checked to be ``numpy.int64``.
    The class must have attributes *axis*, *keepdims*.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=expected_attributes,
                            **options)
        # Both attributes are required, keepdims is checked first.
        for name in ('keepdims', 'axis'):
            if not hasattr(self, name):
                raise AttributeError(  # pragma: no cover
                    "Attribute '%s' is missing." % name)

    def run(self, x):  # pylint: disable=E0202
        """
        Calls method ``_run`` and checks the type
        of the first output.
        """
        outputs = OpRunUnary.run(self, x)
        out_dtype = outputs[0].dtype
        if out_dtype != numpy.int64:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: should be '{}' != output '{}' "
                "(operator '{}')".format(
                    numpy.int64, out_dtype, self.__class__.__name__))
        return outputs

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns the shape of the input reduced along *axis*.
        """
        reduced = x.reduce(self.axis, self.keepdims,  # pylint: disable=E1101
                           dtype=numpy.int64)  # pylint: disable=E1101
        return (reduced, )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        The output type is always ``numpy.int64``.
        """
        return (numpy.int64, )

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        """
        Calls method ``_run`` without checking the output type.
        """
        return OpRunUnary.run(self, x)
class OpRunUnaryNum(OpRunUnary):
    """
    Ancestor to all unary and numerical operators
    in this subfolder. Checks that the output has the
    same type as the input.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=expected_attributes,
                            **options)

    def run(self, x):  # pylint: disable=E0202
        """
        Calls method ``_run`` and checks the type of the first
        output unless it is a list.
        """
        outputs = OpRunUnary.run(self, x)
        first = outputs[0]
        if isinstance(first, list):
            return outputs
        if first.dtype != x.dtype:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: input '{}' != output '{}' "
                "(operator '{}')".format(
                    x.dtype, first.dtype, self.__class__.__name__))
        return outputs

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        """
        Calls method ``_run`` without checking the output type.
        """
        return OpRunUnary.run(self, x)
class OpRunClassifierProb(OpRunUnary):
    """
    Ancestor to all classifiers in this subfolder returning
    labels and probabilities. Checks that the probabilities
    have the same type as the input when the input is
    a float or a double.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=expected_attributes,
                            **options)

    def run(self, x):  # pylint: disable=E0202
        """
        Calls method ``_run`` and checks the type
        of the second output (probabilities).
        """
        res = OpRunUnary.run(self, x)
        if x.dtype in (numpy.float32, numpy.float64) and res[1].dtype != x.dtype:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: {} != {} (operator '{}')".format(
                    x.dtype, res[1].dtype, self.__class__.__name__))
        return res

    @property
    def nb_classes(self):
        """
        Returns the number of expected classes, the length of the
        longest list among attributes *classlabels_ints*,
        *classlabels_int64s*, *classlabels_strings*.
        A missing attribute counts for zero.
        """
        # The three attributes are read the same way: an operator
        # defining only integer labels must not fail on the strings.
        return max(len(getattr(self, name, []))
                   for name in ('classlabels_ints',
                                'classlabels_int64s',
                                'classlabels_strings'))

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        """
        Calls method ``_run`` without checking the output type.
        """
        return OpRunUnary.run(self, x)

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns the shapes of the labels and the probabilities,
        both built from the first dimension of the input.
        """
        return (ShapeObject((x[0], ), dtype=numpy.int64,
                            name="{}-0".format(self.__class__.__name__)),
                ShapeObject((x[0], self.nb_classes), dtype=x.dtype,
                            name="{}-1".format(self.__class__.__name__)))

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the type of the labels and the probabilities.
        """
        return (numpy.int64, x.dtype)
class OpRunBinary(OpRun):
    """
    Ancestor to all binary operators in this subfolder.
    Checks that inputs type are the same.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        """
        @param      onnx_node               :epkg:`onnx` node
        @param      desc                    internal representation
        @param      expected_attributes     expected attributes for this node
        @param      options                 runtime options
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def run(self, x, y):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run`` after checking both inputs are
        specified and share the same type.
        """
        if x is None or y is None:
            # The issue is a missing input, not a type mismatch,
            # the message says so.
            raise RuntimeError(
                "x and y must be specified, got types {} and {} ({})".format(
                    type(x), type(y), type(self)))
        if x.dtype != y.dtype:
            raise RuntimeTypeError(
                "Input type mismatch: {} != {} (operator '{}', shapes {}, {})".format(
                    x.dtype, y.dtype, self.__class__.__name__,
                    x.shape, y.shape))
        try:
            res = self._run(x, y)
        except (TypeError, ValueError) as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (binary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x, y]),
                    self.__class__.__name__)) from e
        return res

    def _run_no_checks_(self, x, y):  # pylint: disable=W0221
        """
        Calls method ``_run`` without checking the inputs.
        """
        try:
            res = self._run(x, y)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (binary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x, y]),
                    self.__class__.__name__)) from e
        return res

    def _infer_shapes(self, x, y):  # pylint: disable=W0221
        """
        Returns the same shape by default.
        We assume the operator returns the biggest
        shapes as the operator could be using broadcasting.
        """
        try:
            res = x.broadcast(y)
            add = "broadcast"
        except RuntimeError:  # pragma: no cover
            # We know x and y and the same number of dimensions.
            # We pick the first one even if it might be wrong.
            res = x
            add = "1"
        if res.name is None:
            return (res.copy(name="{}{}".format(
                self.__class__.__name__, add)), )
        return (res.copy(name="{}-{}{}".format(
            res.name, self.__class__.__name__, add)), )

    def _infer_types(self, x, y):  # pylint: disable=W0221
        """
        Returns the type of the first input.
        """
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and returns ``dict(temp=0)``
        followed by the outputs.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
class OpRunBinaryComparison(OpRunBinary):
    """
    Ancestor to all binary operators in this subfolder
    which compare two tensors. The inferred output type
    is boolean.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunBinary.__init__(self, onnx_node, desc=desc,
                             expected_attributes=expected_attributes,
                             **options)

    def _infer_types(self, x, y):  # pylint: disable=W0221
        """
        Returns the boolean type whatever the inputs are.
        """
        return (numpy.bool_, )
class OpRunBinaryNum(OpRunBinary):
    """
    Ancestor to all binary and numerical operators in this
    subfolder. Checks that the inputs and the first output
    share the same type.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunBinary.__init__(self, onnx_node, desc=desc,
                             expected_attributes=expected_attributes,
                             **options)

    def run(self, x, y):  # pylint: disable=E0202
        """
        Calls method ``_run`` and checks the type
        of the first output.
        """
        outputs = OpRunBinary.run(self, x, y)
        out_dtype = outputs[0].dtype
        if out_dtype != x.dtype:
            raise RuntimeTypeError(
                "Output type mismatch: {} != {} (operator '{}')".format(
                    x.dtype, out_dtype, self.__class__.__name__))
        return outputs

    def _run_no_checks_(self, x, y):  # pylint: disable=W0221
        """
        Calls method ``_run`` without any check on the types.
        """
        return OpRunBinary._run_no_checks_(self, x, y)
class OpRunBinaryNumpy(OpRunBinaryNum):
    """
    Implements the inplaces logic.
    *numpy_fct* is a binary numpy function which
    takes two matrices and has an argument *out*
    for inplace operations.
    """

    def __init__(self, numpy_fct, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRunBinaryNum.__init__(self, onnx_node, desc=desc,
                                expected_attributes=expected_attributes,
                                **options)
        self.numpy_fct = numpy_fct
        # For a division, integer inputs are never computed inplace.
        self._cannot_inplace_int = self.numpy_fct in (
            numpy.divide, numpy.true_divide)

    def _run(self, a, b):  # pylint: disable=W0221
        """
        Applies *numpy_fct* on both inputs, inplace in one of them
        when it was allowed through *enable_inplace_compute*
        and the sizes make it possible.
        """
        fct = self.numpy_fct
        if (self._cannot_inplace_int and
                numpy.issubdtype(a.dtype, numpy.integer)):
            return (fct(a, b), )

        # The first input receives the result.
        if self.inplaces.get(0, False) and a.size >= b.size:
            if len(a.shape) == 1 and b.shape == (1, 1):
                a = a.reshape(1, a.shape[0])
            try:
                fct(a, b, out=a)
            except (ValueError, TypeError):
                # The inplace computation failed, a new array is created.
                return (fct(a, b), )
            return (a, )

        # The second input receives the result.
        if self.inplaces.get(1, False) and a.size <= b.size:
            if len(b.shape) == 1 and a.shape == (1, 1):
                b = b.reshape(b.shape[0], 1)
            try:
                fct(a, b, out=b)
            except (ValueError, TypeError):
                return (fct(a, b), )
            return (b, )

        return (fct(a, b), )

    def to_python(self, inputs):
        """
        Returns a python code equivalent to this operator.

        @param      inputs      inputs name
        @return                 imports, python code, both as strings
        """
        comment = "# inplaces not take into account {}-{}".format(
            self.inplaces.get(0, False), self.inplaces.get(1, False))
        call = "return numpy.{0}({1})".format(
            self.numpy_fct.__name__, ', '.join(inputs))
        return "import numpy", "\n".join([comment, call])
class OpRunReduceNumpy(OpRunUnaryNum):
    """
    Implements the reduce logic.
    It must have a parameter *axes*. After the initialization,
    attribute *axes* is either None (empty) or a tuple.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        """
        @param      onnx_node               :epkg:`onnx` node
        @param      desc                    internal representation
        @param      expected_attributes     expected attributes for this node,
                                            it must contain *axes* or
                                            *noop_with_empty_axes*
        @param      options                 runtime options
        """
        # expected_attributes defaults to None, the checks are done
        # on an empty dictionary in that case to raise the
        # explicit exception below instead of a TypeError.
        expected = {} if expected_attributes is None else expected_attributes
        if ('noop_with_empty_axes' not in expected and
                'axes' not in expected):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'axes' is expected but not found in {} "
                "from class {}".format(expected_attributes, type(self)))
        if expected.get('noop_with_empty_axes', 0):
            # 'axes' may be missing when only 'noop_with_empty_axes'
            # is given, it is then considered as empty.
            axes = expected.get('axes', None)
            if axes is None or len(axes) == 0:
                raise RuntimeError(  # pragma: no cover
                    "Parameter 'axes' cannot be empty as {} (noop_with_empty_axes=1) "
                    "from class {}".format(expected_attributes, type(self)))
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=expected_attributes,
                               **options)
        # axes is normalized: None if empty, a tuple otherwise.
        if isinstance(self.axes, numpy.ndarray):  # pylint: disable=E0203
            if (len(self.axes.shape) == 0 or  # pylint: disable=E0203
                    self.axes.shape[0] == 0):  # pylint: disable=E0203
                self.axes = None
            else:
                self.axes = tuple(self.axes)
        elif self.axes in [[], tuple()]:  # pylint: disable=E0203
            self.axes = None
        elif isinstance(self.axes, list):  # pylint: disable=E0203
            self.axes = tuple(self.axes)
class OpRunCustom(OpRun):
    """
    Automates some methods for custom operators defined
    outside *mlprodict*.
    """

    class OpRunCustomSchema(OperatorSchema):
        """
        Custom schema, it is named after the class
        and exposes its attributes *atts*.
        """

        def __init__(self, cls):
            OperatorSchema.__init__(self, cls.__name__)
            self.attributes = cls.atts

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Finds a custom operator defined by this runtime.
        The name must be the class name or the class
        attribute *op_name* if it exists.
        """
        cls = self.__class__
        same_class_name = op_name == cls.__name__
        same_op_name = (
            hasattr(cls, 'op_name') and
            cls.op_name == op_name)  # pylint: disable=E1101
        if not (same_class_name or same_op_name):
            raise RuntimeError(  # pragma: no cover
                "Unable to find a schema for operator '{}'.".format(op_name))
        return OpRunCustom.OpRunCustomSchema(cls)