# Source code for mlprodict.onnxrt.ops_cpu.op_gather

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.


:githublink:`%|py|7`
"""
import numpy
from ._op import OpRun
from ..shape_object import ShapeObject
from .op_gather_ import (  # pylint: disable=E0611,E0401
    GatherFloat, GatherDouble, GatherInt64)


class Gather(OpRun):
    """
    Runtime implementation of the *Gather* operator.

    Inputs whose dtype is float32, float64 or int64 are dispatched to the
    compiled kernels ``GatherFloat``, ``GatherDouble`` and ``GatherInt64``.
    Every other dtype falls back to :func:`numpy.take`.
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Gather.atts,
                       **options)
        # One compiled kernel per supported input dtype, keyed by
        # ``str(dtype)`` so that ``_run`` can select it directly.
        # NOTE(review): ``self.axis`` is presumably set by ``OpRun.__init__``
        # from ``expected_attributes`` (default 0) -- confirm in ``_op``.
        self.rt_ = {
            'float32': GatherFloat(self.axis),
            'float64': GatherDouble(self.axis),
            'int64': GatherInt64(self.axis)}

    def _run(self, x, indices):  # pylint: disable=W0221
        """
        Gathers entries of *x* along ``self.axis`` at positions *indices*.

        :param x: input array (``numpy.ndarray``)
        :param indices: array of indices (``numpy.ndarray``)
        :return: a 1-tuple holding the gathered array
        """
        # The compiled kernels are given C-contiguous buffers; copy only
        # when the input is not already laid out that way.
        if not x.flags['C_CONTIGUOUS']:
            x = numpy.ascontiguousarray(x)
        if not indices.flags['C_CONTIGUOUS']:
            # ``ndarray`` has no ``ascontiguousarray`` method; the module
            # level function must be used (the method call raised
            # AttributeError for non-contiguous indices).
            indices = numpy.ascontiguousarray(indices)
        kernel = self.rt_.get(str(x.dtype), None)
        if kernel is None:
            # No compiled kernel for this dtype: generic numpy fallback.
            return (numpy.take(x, indices, axis=self.axis), )
        # Kept outside any ``try`` so that errors raised by the kernel itself
        # are not mistaken for an unsupported dtype.
        return (kernel.compute(x, indices), )

    def _infer_shapes(self, x, indices):  # pylint: disable=E0202,W0221
        """
        Returns the same shape by default.

        :githublink:`%|py|40`
        """
        return (ShapeObject.gather_shape(x, indices, self.axis), )