Source code for mlprodict.onnxrt.ops_cpu.op_argmax

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.


:githublink:`%|py|7`
"""
import numpy
from onnx.defs import onnx_opset_version
from ._op import OpRunArg


[docs]def _argmax(data, axis=0, keepdims=True): result = numpy.argmax(data, axis=axis) if keepdims and len(result.shape) < len(data.shape): result = numpy.expand_dims(result, axis) return result.astype(numpy.int64)
[docs]def _argmax_use_numpy_select_last_index( data, axis=0, keepdims=True): data = numpy.flip(data, axis) result = numpy.argmax(data, axis=axis) result = data.shape[axis] - result - 1 if keepdims and len(result.shape) < len(data.shape): result = numpy.expand_dims(result, axis) return result.astype(numpy.int64)
[docs]class _ArgMax(OpRunArg): """ Base class for runtime for operator `ArgMax <https://github.com/onnx/onnx/blob/master/docs/ Operators.md#ArgMax>`_. :githublink:`%|py|34` """
[docs] def __init__(self, onnx_node, desc=None, expected_attributes=None, **options): OpRunArg.__init__(self, onnx_node, desc=desc, expected_attributes=expected_attributes, **options)
[docs] def _run(self, data): # pylint: disable=W0221 return (_argmax(data, axis=self.axis, keepdims=self.keepdims), )
[docs] def to_python(self, inputs): return ('import numpy\nfrom mlprodict.onnxrt.ops_cpu.op_argmax import _argmax', 'return _argmax(%s, axis=axis, keepdims=keepdims)' % inputs[0])
[docs]class ArgMax_11(_ArgMax): atts = {'axis': 0, 'keepdims': 1}
[docs] def __init__(self, onnx_node, desc=None, **options): _ArgMax.__init__(self, onnx_node, desc=desc, expected_attributes=ArgMax_11.atts, **options)
[docs]class ArgMax_12(_ArgMax): atts = {'axis': 0, 'keepdims': 1, 'select_last_index': 0}
[docs] def __init__(self, onnx_node, desc=None, **options): _ArgMax.__init__(self, onnx_node, desc=desc, expected_attributes=ArgMax_12.atts, **options)
[docs] def _run(self, data): # pylint: disable=W0221 if self.select_last_index == 0: return _ArgMax._run(self, data) return (_argmax_use_numpy_select_last_index( data, axis=self.axis, keepdims=self.keepdims), )
[docs] def to_python(self, inputs): lines = [ "if select_last_index == 0:", " return _argmax({0}, axis=axis, keepdims=keepdims)", "return _argmax_use_numpy_select_last_index(", " {0}, axis=axis, keepdims=keepdims)"] return ('import numpy\nfrom mlprodict.onnxrt.ops_cpu.op_argmax import _argmax, _argmax_use_numpy_select_last_index', "\n".join(lines).format(inputs[0]))
if onnx_opset_version() >= 12: ArgMax = ArgMax_12 else: ArgMax = ArgMax_11 # pragma: no cover