Source code for mlprodict.onnxrt.ops_cpu.op_split

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.


:githublink:`%|py|7`
"""
from onnx.defs import onnx_opset_version
from ._op import OpRun
from ..shape_object import DimensionObject, ShapeObject


class CommonSplit(OpRun):
    """
    Runtime for operator *Split*.

    Shared implementation for the opset-specific *Split* runtimes.
    The methods below read ``self.axis``; it is presumably populated by
    :class:`OpRun` from the node attributes declared by the subclass
    (``expected_attributes``) — TODO confirm against ``OpRun``.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        # Default ``split`` to None so nodes without the attribute fall
        # back to the even split computed in common_run.
        options.setdefault('split', None)
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes, **options)
        # One output tensor is produced per declared node output.
        self.nb_outputs = len(onnx_node.output)

    def common_run(self, mat, split):  # pylint: disable=W0221
        """
        Splits *mat* along ``self.axis``.

        :param mat: array to split (indexed with a tuple of slices)
        :param split: sequence of part lengths along ``self.axis``, or
            None to cut the axis into ``self.nb_outputs`` parts of
            ``mat.shape[self.axis] // self.nb_outputs`` elements, the last
            part also receiving the remainder
        :return: tuple of slices of *mat*, one per part (views when *mat*
            is a numpy array, since only basic slicing is used)
        """
        if split is None:
            axis_len = mat.shape[self.axis]
            div = axis_len // self.nb_outputs
            split = [div] * self.nb_outputs
            # Leftover elements go to the last part.
            split[-1] += axis_len - sum(split)
        # Take every dimension fully except the split axis.
        sli = [slice(None)] * len(mat.shape)
        res = []
        pos = 0
        for spl in split:
            sli[self.axis] = slice(pos, pos + spl)
            pos += spl
            res.append(mat[tuple(sli)])
        return tuple(res)

    def common_infer_shapes(self, data, split):  # pylint: disable=W0221
        """
        Infers the output shapes of :meth:`common_run`.

        :param data: ShapeObject describing the input
        :param split: sequence of part lengths along ``self.axis``, or None
        :return: tuple of ShapeObject, one per part; when *split* is None
            the shapes are unknown (``ShapeObject(None, ...)``) but keep
            the input dtype
        """
        if split is None:
            return tuple(ShapeObject(None, dtype=data.dtype)
                         for _ in range(self.nb_outputs))
        res = []
        for spl in split:
            # Same shape as the input except along the split axis.
            shape = data.copy()
            shape[self.axis] = DimensionObject(spl)
            res.append(shape)
        return tuple(res)
class Split_2(CommonSplit):
    """
    Runtime for operator *Split* (selected for opsets older than 11).

    The part lengths are read from ``self.split``, presumably filled from
    the node's ``split`` attribute (declared in ``atts``, default None).
    """

    atts = {'axis': 0, 'split': None}

    def __init__(self, onnx_node, desc=None, **options):
        CommonSplit.__init__(self, onnx_node, desc=desc,
                             expected_attributes=Split_2.atts, **options)

    def _run(self, mat):  # pylint: disable=W0221
        # None here means "split evenly", handled by common_run.
        lengths = self.split
        return self.common_run(mat, lengths)

    def _infer_shapes(self, data):  # pylint: disable=W0221
        lengths = self.split
        return self.common_infer_shapes(data, lengths)
class Split_11(Split_2):
    """
    Runtime for operator *Split* (opset 11); reuses :class:`Split_2`
    without any change.
    """
class Split_13(CommonSplit):
    """
    Runtime for operator *Split* (opset 13 and later).

    The part lengths are no longer an attribute: they arrive as an
    optional second input of :meth:`_run`.
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        CommonSplit.__init__(self, onnx_node, desc=desc,
                             expected_attributes=Split_13.atts, **options)

    def _run(self, mat, split=None):  # pylint: disable=W0221
        return self.common_run(mat, split)

    def _infer_shapes(self, data, split=None):  # pylint: disable=W0221
        # The lengths are only known at run time, so every output shape is
        # reported as unknown (with the input dtype), whatever *split* is.
        return self.common_infer_shapes(data, None)
# Expose the runtime matching the opset supported by the installed onnx
# package under the generic name ``Split``.
_onnx_opset = onnx_opset_version()
if _onnx_opset >= 13:
    Split = Split_13
elif _onnx_opset >= 11:
    Split = Split_11
else:
    Split = Split_2