# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.
:githublink:`%|py|7`
"""
import numpy
from onnx.defs import onnx_opset_version
from ._op import OpRunReduceNumpy, RuntimeTypeError, OpRun
[docs]class ReduceSum_1(OpRunReduceNumpy):
atts = {'axes': [], 'keepdims': 1}
[docs] def __init__(self, onnx_node, desc=None, **options):
OpRunReduceNumpy.__init__(self, onnx_node, desc=desc,
expected_attributes=ReduceSum_1.atts,
**options)
[docs] def _run(self, data): # pylint: disable=W0221
return (numpy.sum(data, axis=self.axes,
keepdims=self.keepdims,
dtype=data.dtype), )
[docs]class ReduceSum_11(ReduceSum_1):
[docs] def __init__(self, onnx_node, desc=None, **options):
ReduceSum_1.__init__(self, onnx_node, desc=desc, **options)
[docs]class ReduceSum_13(OpRunReduceNumpy):
atts = {'axes': [], 'keepdims': 1, 'noop_with_empty_axes': 0}
[docs] def __init__(self, onnx_node, desc=None, **options):
OpRunReduceNumpy.__init__(self, onnx_node, desc=desc,
expected_attributes=ReduceSum_13.atts,
**options)
[docs] def run(self, data, axes=None): # pylint: disable=E0202,W0221
"""
Calls method ``_run``.
:githublink:`%|py|45`
"""
res = self._run(data, axes=axes)
if res[0].dtype != data.dtype:
raise RuntimeTypeError( # pragma: no cover
"Output type mismatch: input '{}' != output '{}' "
"(operator '{}')".format(
data.dtype, res[0].dtype, self.__class__.__name__))
return res
[docs] def _run_no_checks_(self, x, axes=None): # pylint: disable=W0221
return OpRun.run(self, x, axes)
[docs] def _run(self, data, axes=None): # pylint: disable=W0221
if axes is None and self.noop_with_empty_axes:
return (data, )
if axes is not None and not isinstance(axes, int):
if isinstance(axes, numpy.ndarray) and len(axes.shape) == 0:
axes = int(axes)
else:
axes = tuple(axes) if len(axes) > 0 else None
try:
return (numpy.sum(data, axis=axes,
keepdims=self.keepdims,
dtype=data.dtype), )
except TypeError as e:
raise TypeError(
"Unable to reduce shape %r with axes=%r." % (
data.shape, axes)) from e
[docs] def infer_shapes(self, data, axes=None): # pylint: disable=E0202,W0221
return self._infer_shapes(data, axes=axes)
[docs] def _infer_shapes(self, data, axes=None): # pylint: disable=W0221
"""
Returns the same shape by default.
:githublink:`%|py|80`
"""
sh = data.reduce(axes, self.keepdims, # pylint: disable=E1101
dtype=numpy.int64) # pylint: disable=E1101
return (sh, )
if onnx_opset_version() >= 13:
ReduceSum = ReduceSum_13
elif onnx_opset_version() >= 11:
ReduceSum = ReduceSum_11
else:
ReduceSum = ReduceSum_1