# Source code for mlprodict.onnxrt.ops_cpu.op_split
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.
:githublink:`%|py|7`
"""
from onnx.defs import onnx_opset_version
from ._op import OpRun
from ..shape_object import DimensionObject, ShapeObject
class CommonSplit(OpRun):
    """
    Shared runtime logic for operator *Split*.

    Subclasses decide where the split sizes come from (an attribute or a
    runtime input) and delegate the slicing to :meth:`common_run` and the
    shape inference to :meth:`common_infer_shapes`.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        # Guarantee a 'split' option (default None) exists before OpRun
        # processes the options.
        if 'split' not in options:
            options['split'] = None
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)
        # Number of outputs declared on the node; used to cut the input
        # into equal chunks when no explicit sizes are provided.
        self.nb_outputs = len(onnx_node.output)

    def common_run(self, mat, split):  # pylint: disable=W0221
        """
        Splits *mat* along ``self.axis``.

        :param mat: array to split
        :param split: sequence of chunk sizes along ``self.axis``, or None
            to cut the axis into ``self.nb_outputs`` chunks of size
            ``mat.shape[axis] // nb_outputs``; in that case the last chunk
            also receives the remainder
        :return: tuple of sub-arrays, one per chunk, in order
        """
        if split is None:
            length = mat.shape[self.axis]
            div = length // self.nb_outputs
            split = [div] * self.nb_outputs
            # The last chunk absorbs whatever an even split leaves over.
            split[-1] += length - sum(split)
        # Full-range slice on every axis; only self.axis is narrowed below.
        sli = [slice(None)] * len(mat.shape)
        res = []
        pos = 0
        for spl in split:
            sli[self.axis] = slice(pos, pos + spl)
            pos += spl
            res.append(mat[tuple(sli)])
        return tuple(res)

    def common_infer_shapes(self, data, split):  # pylint: disable=W0221
        """
        Infers the output shapes of the split.

        :param data: shape of the input (a ShapeObject)
        :param split: sequence of chunk sizes along ``self.axis``, or None
            when the sizes are not known, in which case every output gets
            an unknown shape with the input dtype
        :return: tuple of shapes, one per output
        """
        if split is None:
            return tuple(ShapeObject(None, dtype=data.dtype)
                         for _ in range(self.nb_outputs))
        res = []
        for spl in split:
            # Each output keeps the input shape except along the split axis.
            shape = data.copy()
            shape[self.axis] = DimensionObject(spl)
            res.append(shape)
        return tuple(res)
class Split_2(CommonSplit):
    """
    Runtime for operator *Split*, attribute-based variant: the chunk
    sizes are read from the ``split`` attribute (None means equal chunks).
    """

    atts = {'axis': 0, 'split': None}

    def __init__(self, onnx_node, desc=None, **options):
        CommonSplit.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Split_2.atts, **options)

    def _run(self, mat):  # pylint: disable=W0221
        # Sizes come from the node attribute, not from an input.
        sizes = self.split
        return self.common_run(mat, sizes)

    def _infer_shapes(self, data):  # pylint: disable=W0221
        sizes = self.split
        return self.common_infer_shapes(data, sizes)
class Split_11(Split_2):
    """
    Runtime for operator *Split* selected for opset 11 and 12; it reuses
    the attribute-based behaviour of :class:`Split_2` unchanged.
    """
class Split_13(CommonSplit):
    """
    Runtime for operator *Split* where the chunk sizes are passed as an
    optional input to :meth:`_run` instead of as an attribute.
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        CommonSplit.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Split_13.atts, **options)

    def _run(self, mat, split=None):  # pylint: disable=W0221
        return self.common_run(mat, split)

    def _infer_shapes(self, data, split=None):  # pylint: disable=W0221
        # The sizes are an input and are not used here, so every output
        # shape is reported as unknown (input dtype, one per output).
        return self.common_infer_shapes(data, None)
# Pick the runtime class matching the opset supported by the installed onnx.
Split = (Split_13 if onnx_opset_version() >= 13
         else Split_11 if onnx_opset_version() >= 11
         else Split_2)