Source code for mlprodict.onnxrt.ops_cpu.op_squeeze

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime implementation of the ONNX operator *Squeeze*.


:githublink:`%|py|7`
"""
import numpy
from onnx.defs import onnx_opset_version
from ..shape_object import ShapeObject
from ._op import OpRunUnaryNum, OpRun


class Squeeze_1(OpRunUnaryNum):
    """
    Runtime for operator *Squeeze* when the axes are given as the
    node attribute ``axes``.

    After construction ``self.axes`` is either ``None`` (the attribute
    was an empty list or tuple, every dimension of size one is removed)
    or a tuple of axes to remove.
    """

    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the operator and normalises ``self.axes``.

        :param onnx_node: node forwarded to the base class constructor
        :param desc: forwarded to the base class constructor
        :param options: forwarded to the base class constructor
        """
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Squeeze_1.atts,
                               **options)
        # ``self.axes`` is read right after the base constructor returns;
        # presumably it is populated from *expected_attributes*
        # (NOTE(review): confirm in OpRunUnaryNum).
        if isinstance(self.axes, numpy.ndarray):
            self.axes = tuple(self.axes)
        elif self.axes in [[], tuple()]:
            # An empty list/tuple means "no axis specified".
            self.axes = None
        elif isinstance(self.axes, list):
            self.axes = tuple(self.axes)

    def _run(self, data):  # pylint: disable=W0221
        """
        Removes the dimensions listed in ``self.axes`` from *data*.

        :param data: array to squeeze
        :return: tuple with one element, the squeezed array
        :raises ValueError: raised by :func:`numpy.squeeze` when a
            selected dimension is not of size one
        """
        axes = self.axes
        if isinstance(axes, list):
            # numpy.squeeze expects None, an int or a tuple of ints.
            axes = tuple(axes)
        # A single call resolves every axis (unsorted or negative)
        # against the rank of the original array. Squeezing one axis at
        # a time shifts the remaining indices and fails or picks the
        # wrong dimension when the axes are not sorted or are negative.
        return (numpy.squeeze(data, axis=axes), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Infers the output shape by calling ``x.squeeze`` with
        ``self.axes``.

        :param x: shape object of the input
        :return: tuple with one element, the result of ``x.squeeze``
        """
        return (x.squeeze(axis=self.axes), )
class Squeeze_11(Squeeze_1):
    """
    Same runtime as :class:`Squeeze_1`, nothing is overridden.
    """
class Squeeze_13(OpRun):
    """
    Runtime for operator *Squeeze* when the axes are received as an
    optional second input of ``_run`` instead of a node attribute.
    """

    atts = {'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the operator.

        :param onnx_node: node forwarded to the base class constructor
        :param desc: forwarded to the base class constructor
        :param options: forwarded to the base class constructor
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Squeeze_13.atts,
                       **options)
        # Axes are given at run time; the attribute stays None.
        self.axes = None

    def _run(self, data, axes=None):  # pylint: disable=W0221
        """
        Removes dimensions of size one from *data*.

        :param data: array to squeeze
        :param axes: iterable of integer axes to remove, or None to
            remove every dimension of size one
        :return: tuple with one element, the squeezed array
        :raises ValueError: raised by :func:`numpy.squeeze` when a
            selected dimension is not of size one
        """
        if axes is None:
            return (numpy.squeeze(data), )
        # All axes are resolved against the rank of the original array
        # in one call. Squeezing them one by one after sorting breaks
        # with negative axes: a negative index may become out of bounds
        # (or designate another dimension) once a positive axis has
        # already been removed.
        selected = tuple(int(a) for a in axes)
        return (numpy.squeeze(data, axis=selected), )

    def _infer_shapes(self, x, axes=None):  # pylint: disable=W0221
        """
        Returns a shape object built with ``None`` as shape and the
        dtype of *x*; *axes* is not used.

        :param x: shape object of the input
        :param axes: ignored
        :return: tuple with one element, a :class:`ShapeObject`
        """
        return (ShapeObject(None, dtype=x.dtype), )
# Alias ``Squeeze`` to the implementation matching the opset version
# reported by the installed onnx package.
Squeeze = (
    Squeeze_13 if onnx_opset_version() >= 13
    else Squeeze_11 if onnx_opset_version() >= 11
    else Squeeze_1
)