# Source code for mlprodict.onnxrt.ops_cpu.op_slice

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator for ONNX *Slice* (an attribute-based opset 1 variant and an input-based opset 10+ variant).


:githublink:`%|py|7`
"""
from onnx.defs import onnx_opset_version
from ..shape_object import ShapeObject
from ._op import OpRun


class SliceCommon(OpRun):
    """
    Shared implementation of the *Slice* operator. Subclasses only differ
    in where *starts*, *ends*, *axes* and *steps* come from before they
    reach :meth:`_run`.
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc, **options)

    def _run(self, data, starts, ends, axes=None, steps=None):  # pylint: disable=W0221
        """
        Slices *data* and returns the result as a 1-tuple.

        :param data: array to slice (indexed with a tuple of ``slice``)
        :param starts: start index of each sliced axis
        :param ends: exclusive end index of each sliced axis
        :param axes: axes the (start, end[, step]) entries apply to;
            when None they apply to the leading axes, in order
        :param steps: step of each sliced axis; when None the ``slice``
            objects are built without a step
        :return: tuple containing the sliced array

        Pairing follows :func:`zip`, so the shortest of the given
        sequences decides how many axes are sliced.
        """
        # One slice object per (start, end[, step]) entry.
        if steps is None:
            bounds = [slice(beg, fin) for beg, fin in zip(starts, ends)]
        else:
            bounds = [slice(beg, fin, stride)
                      for beg, fin, stride in zip(starts, ends, steps)]

        if axes is None:
            # Trailing dimensions not covered by *bounds* are left whole
            # by tuple indexing.
            index = bounds
        else:
            # Start from a full-range slice on every dimension, then
            # overwrite the requested axes.
            index = [slice(0, dim) for dim in data.shape]
            for axis, bound in zip(axes, bounds):
                index[axis] = bound
        return (data[tuple(index)], )

    def _infer_shapes(self, data, starts, ends, axes=None, steps=None):  # pylint: disable=W0221
        """
        Returns a :class:`ShapeObject` with the rank and dtype of *data*.
        The output dimensions are not computed from *starts*/*ends*.
        Each dimension gets a symbolic name made unique to this node
        instance through ``id(self)``.
        """
        tag = hex(id(self))[2:]
        dims = ["nslice{}_{}".format(tag, rank)
                for rank in range(len(data.shape))]
        return (ShapeObject(dims, data.dtype), )
class Slice_10(SliceCommon):
    """
    *Slice* used when the installed onnx opset is 10 or higher.
    *starts*, *ends*, *axes* and *steps* are received as arguments of
    the inherited :meth:`SliceCommon._run`.
    """

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc, **options)
class Slice_1(SliceCommon):
    """
    *Slice* used when the installed onnx opset is below 10.
    *starts*, *ends* and *axes* are node attributes (declared in
    :attr:`atts`) instead of runtime arguments.
    """

    # Attributes expected on the node, with their default values.
    atts = {'starts': [], 'ends': [], 'axes': []}

    def __init__(self, onnx_node, desc=None, **options):
        SliceCommon.__init__(self, onnx_node, desc=desc,
                             expected_attributes=Slice_1.atts,
                             **options)
        # Empty attribute lists are replaced by None. For *axes* this
        # makes SliceCommon._run take its "leading axes" path.
        for name in ('starts', 'ends', 'steps', 'axes'):
            value = getattr(self, name, None)
            if value is not None and len(value) == 0:
                setattr(self, name, None)

    def _run(self, data):  # pylint: disable=W0221
        """Slices *data* with the attribute values stored on the node."""
        return SliceCommon._run(self, data, self.starts, self.ends,
                                axes=self.axes)

    def _infer_shapes(self, data):  # pylint: disable=W0221
        """Delegates to :meth:`SliceCommon._infer_shapes` with the stored attributes."""
        return SliceCommon._infer_shapes(self, data, self.starts, self.ends,
                                         axes=self.axes)
# ``Slice`` is the implementation matching the opset reported by the
# installed onnx package.
if onnx_opset_version() < 10:
    Slice = Slice_1  # pragma: no cover
else:
    Slice = Slice_10