Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRunReduceNumpy, RuntimeTypeError, OpRun
class ReduceSum_1(OpRunReduceNumpy):
    """
    Runtime for *ReduceSum* where *axes* and *keepdims* are node
    attributes. It is selected when the installed onnx opset is below 11
    and is the base class of ReduceSum_11.
    """

    # attributes handed to the base class as expected attributes
    # (presumably with the values used when the node omits them)
    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunReduceNumpy.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ReduceSum_1.atts, **options)

    def _run(self, data):  # pylint: disable=W0221
        """
        Sums *data* over ``self.axes`` and keeps the input dtype.

        :param data: input array
        :return: 1-tuple holding the reduced array
        """
        total = numpy.sum(data, axis=self.axes,
                          keepdims=self.keepdims, dtype=data.dtype)
        return (total, )
class ReduceSum_11(ReduceSum_1):
    """
    Runtime for *ReduceSum* selected when the installed onnx opset is
    11 or 12. It reuses ReduceSum_1 without any change.
    """

    def __init__(self, onnx_node, desc=None, **options):
        # same attributes and same computation as ReduceSum_1
        ReduceSum_1.__init__(self, onnx_node, desc=desc, **options)
class ReduceSum_13(OpRunReduceNumpy):
    """
    Runtime for *ReduceSum* where *axes* is an optional input instead of
    an attribute, with the extra attribute *noop_with_empty_axes*.
    """

    # attributes handed to the base class as expected attributes
    # (presumably with the values used when the node omits them)
    atts = {'axes': [], 'keepdims': 1, 'noop_with_empty_axes': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunReduceNumpy.__init__(self, onnx_node, desc=desc,
                                  expected_attributes=ReduceSum_13.atts,
                                  **options)

    def run(self, data, axes=None):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run``, turns a scalar result into a 1-D array
        when *keepdims* is false and checks the output dtype.

        :param data: input array
        :param axes: axes to reduce; None or empty means every axis,
            or no reduction at all if *noop_with_empty_axes* is set
        :return: 1-tuple holding the result
        :raises RuntimeTypeError: if the output dtype differs from
            the input dtype
        """
        res = self._run(data, axes=axes)
        if not self.keepdims and not isinstance(res[0], numpy.ndarray):
            # numpy.sum returns a scalar when every axis is reduced,
            # this runtime always returns an array
            res = (numpy.array([res[0]], dtype=res[0].dtype), )
        if res[0].dtype != data.dtype:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: input '{}' != output '{}' "
                "(operator '{}')".format(
                    data.dtype, res[0].dtype, self.__class__.__name__))
        return res

    def _run_no_checks_(self, x, axes=None):  # pylint: disable=W0221
        return OpRun.run(self, x, axes)

    @staticmethod
    def _normalize_axes(axes):
        """
        Converts *axes* into a value accepted by :func:`numpy.sum`.

        :param axes: None, an int, a 0-D array or numpy integer scalar,
            or a sequence / 1-D array of axis indices
        :return: None (also returned for an empty sequence, meaning
            "no axis given"), an int, or a tuple of axis indices
        """
        if axes is None or isinstance(axes, int):
            return axes
        if numpy.ndim(axes) == 0:
            # 0-D array or numpy scalar such as numpy.int64,
            # which has no len()
            return int(axes)
        return tuple(axes) if len(axes) > 0 else None

    def _run(self, data, axes=None):  # pylint: disable=W0221
        """
        Sums *data* over *axes* and keeps the input dtype.

        :param data: input array
        :param axes: see :meth:`run`
        :return: 1-tuple holding the result
        :raises TypeError: if numpy cannot reduce *data* over *axes*
        """
        axes = self._normalize_axes(axes)
        if axes is None and self.noop_with_empty_axes:
            # missing or empty axes with noop_with_empty_axes set:
            # the input is returned unchanged
            return (data, )
        try:
            return (numpy.sum(data, axis=axes,
                              keepdims=self.keepdims,
                              dtype=data.dtype), )
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Unable to reduce shape %r with axes=%r." % (
                    data.shape, axes)) from e

    def infer_shapes(self, data, axes=None):  # pylint: disable=E0202,W0221
        return self._infer_shapes(data, axes=axes)

    def _infer_shapes(self, data, axes=None):  # pylint: disable=W0221
        """
        Infers the output shape by calling ``data.reduce`` with *axes*
        and *keepdims*.
        """
        # NOTE(review): run() requires the output dtype to equal the input
        # dtype, yet dtype=numpy.int64 is requested here; confirm how
        # ``reduce`` uses this argument and whether it accepts axes=None.
        sh = data.reduce(axes, self.keepdims,  # pylint: disable=E1101
                         dtype=numpy.int64)  # pylint: disable=E1101
        return (sh, )
# Pick the ReduceSum implementation matching the installed onnx opset.
if onnx_opset_version() < 11:  # pragma: no cover
    ReduceSum = ReduceSum_1
elif onnx_opset_version() < 13:  # pragma: no cover
    ReduceSum = ReduceSum_11
else:
    ReduceSum = ReduceSum_13