Source code for mlprodict.onnxrt.ops_cpu.op_gather_elements

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.


:githublink:`%|py|7`
"""
import numpy
from ._op import OpRun


def gather_numpy_2(self, dim, index):
    """
    Fallback gather used by :func:`gather_numpy` for 2-D tensors.

    Computes ``out[i][j] = self[index[i][j]][j]`` when *dim* is 0 and
    ``out[i][j] = self[i][index[i][j]]`` when *dim* is 1. The previous
    implementation ignored *dim* and only used the first index of every
    row, which was wrong as soon as *index* had more than one column
    or *dim* was 0.

    :param self: data tensor
    :param dim: axis along which to gather (negative values count from
        the last dimension)
    :param index: integer tensor of indices, same rank as *self*;
        negative indices count from the end of the axis
    :return: tensor with the shape of *index* and the dtype of *self*
    """
    # take_along_axis implements exactly out[..., i, ...] =
    # self[..., index[..., i, ...], ...] along *dim* and keeps self's dtype.
    return numpy.take_along_axis(self, index, axis=dim)


def gather_numpy(self, dim, index):
    """
    Gathers values along an axis specified by *dim*.
    For a 3-D tensor the output is specified by:

    ::

        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    :param self: data tensor (``input`` above)
    :param dim: the axis along which to index; negative values count
        from the last dimension
    :param index: a tensor of indices of elements to gather; it must have
        the same shape as *self* in every dimension except *dim*,
        negative indices count from the end of the axis
    :return: tensor of gathered values, with the shape of *index*
    :raises ValueError: if *index* and *self* differ in a dimension
        other than *dim*

    See `How to do scatter and gather operations in numpy?
    <https://stackoverflow.com/questions/46065873/
    how-to-do-scatter-and-gather-operations-in-numpy/46204790#46204790>`_

    :githublink:`%|py|39`
    """
    # Normalize a negative axis first: with dim == -1, ``shape[dim + 1:]``
    # would be ``shape[0:]`` and the check below would compare the wrong
    # dimensions.
    axis = dim + len(self.shape) if dim < 0 else dim
    idx_xsection_shape = index.shape[:axis] + index.shape[axis + 1:]
    self_xsection_shape = self.shape[:axis] + self.shape[axis + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError(  # pragma: no cover
            "Except for dimension {}, all dimensions of "
            "index and self should be the same size".format(dim))
    # numpy.take_along_axis performs this gather for any rank; the former
    # numpy.choose approach is limited in the number of choices it accepts,
    # which capped the length of the gathered axis.
    return numpy.take_along_axis(self, index, axis=axis)
class GatherElements(OpRun):
    """
    Runtime for operator *GatherElements*: gathers elements of the first
    input along attribute *axis* at the positions given by the second input.
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Registers *axis* (default 0) as the expected attribute of the node.
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=GatherElements.atts,
                       **options)

    def _run(self, data, indices):  # pylint: disable=W0221
        """
        Returns a 1-tuple holding the values of *data* gathered along
        ``self.axis`` at the positions given by *indices*.
        """
        y = gather_numpy(data, self.axis, indices)
        return (y, )

    def _infer_shapes(self, data, indices):  # pylint: disable=W0221
        """
        Returns the shape information of the output, which has the shape
        of *indices*.
        """
        # NOTE(review): per the ONNX specification the output element type
        # is the one of *data*, while this returns *indices* unchanged —
        # confirm the dtype it carries is acceptable to callers.
        return (indices, )

    def to_python(self, inputs):
        """
        Returns the import and the python code equivalent to this operator.

        The generated code reads a variable named ``axis``, which must be
        in scope where the code runs (presumably provided by the caller as
        the attribute value — confirm against the code generator).

        :param inputs: names of the two inputs (data, indices)
        :return: tuple ``(imports, code)``, both strings
        """
        # Same computation as the runtime (_run); numpy.choose with
        # mode='wrap' was used before, which silently wrapped out-of-range
        # indices and is limited in the number of choices it accepts.
        lines = ['return numpy.take_along_axis(%s, %s, axis=axis)' % (
            inputs[0], inputs[1])]
        return "import numpy", "\n".join(lines)