Source code for mlprodict.onnxrt.ops_cpu.op_argmin

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.


:githublink:`%|py|7`
"""
import numpy
from onnx.defs import onnx_opset_version
from ._op import OpRunArg


def _argmin(data, axis=0, keepdims=True):
    """
    Returns the indices of the first minimum value found along *axis*.

    :param data: numpy array or array-like (nested list, tuple)
    :param axis: axis along which the minimum is searched
    :param keepdims: if true, the reduced axis is put back as a
        dimension of size 1
    :return: indices as *int64* (array, or numpy scalar when the
        reduction leaves no dimension and *keepdims* is false)
    """
    result = numpy.argmin(data, axis=axis)
    # numpy.ndim is used instead of ``.shape`` because it also accepts
    # array-likes: numpy.argmin takes lists, so the keepdims check
    # must not fail on them with an AttributeError.
    if keepdims and numpy.ndim(result) < numpy.ndim(data):
        result = numpy.expand_dims(result, axis)
    return result.astype(numpy.int64)
def _argmin_use_numpy_select_last_index(data, axis=0, keepdims=True):
    """
    Returns the indices of the last minimum value found along *axis*.

    :param data: input array
    :param axis: axis along which the minimum is searched
    :param keepdims: if true, the reduced axis is put back as a
        dimension of size 1
    :return: indices as *int64*
    """
    # numpy.argmin reports the first minimum; searching the reversed
    # data finds the last one instead.
    reversed_data = numpy.flip(data, axis)
    offset_from_end = numpy.argmin(reversed_data, axis=axis)
    # Converts a position in the reversed data into a position
    # in the original data.
    last_index = reversed_data.shape[axis] - offset_from_end - 1
    if not keepdims:
        return last_index.astype(numpy.int64)
    return numpy.expand_dims(last_index, axis).astype(numpy.int64)
class _ArgMin(OpRunArg):
    """
    Common runtime implementation shared by the versions of operator
    `ArgMin <https://github.com/onnx/onnx/blob/master/docs/Operators.md#ArgMin>`_.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        """
        Forwards every argument unchanged to :class:`OpRunArg`.
        """
        OpRunArg.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)

    def _run(self, data):  # pylint: disable=W0221
        """
        Computes the indices of the first minimum along ``self.axis``.

        :param data: input array
        :return: tuple with a single element, the *int64* indices
        """
        # NOTE(review): self.axis and self.keepdims are presumably set
        # by OpRunArg from the expected attributes -- confirm in _op.py.
        indices = _argmin(data, axis=self.axis, keepdims=self.keepdims)
        return (indices, )
class ArgMin_11(_ArgMin):
    # ArgMin variant selected when the installed onnx opset is below 12.

    # Attributes expected on the node with their default values.
    atts = {'axis': 0, 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Initializes the base class with the attributes in ``atts``.
        """
        _ArgMin.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ArgMin_11.atts, **options)

    def to_python(self, inputs):
        """
        Builds python source code calling ``_argmin``.

        :param inputs: names of the inputs, only the first one is used
        :return: tuple *(import statements, code)*
        """
        header = 'import numpy\nfrom mlprodict.onnxrt.ops_cpu.op_argmin import _argmin'
        code = 'return _argmin(%s, axis=axis, keepdims=keepdims)' % inputs[0]
        return (header, code)
class ArgMin_12(_ArgMin):
    # ArgMin variant selected when the installed onnx opset is 12 or
    # above; it adds the attribute ``select_last_index``.

    # Attributes expected on the node with their default values.
    atts = {'axis': 0, 'keepdims': 1, 'select_last_index': 0}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Initializes the base class with the attributes in ``atts``.
        """
        _ArgMin.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ArgMin_12.atts, **options)

    def _run(self, data):  # pylint: disable=W0221
        """
        Computes the indices of the minimum along ``self.axis``,
        the last one if ``self.select_last_index`` is not zero,
        the first one otherwise.

        :param data: input array
        :return: tuple with a single element, the *int64* indices
        """
        if self.select_last_index != 0:
            indices = _argmin_use_numpy_select_last_index(
                data, axis=self.axis, keepdims=self.keepdims)
            return (indices, )
        return _ArgMin._run(self, data)

    def to_python(self, inputs):
        """
        Builds python source code choosing between ``_argmin`` and
        ``_argmin_use_numpy_select_last_index`` at execution time.

        :param inputs: names of the inputs, only the first one is used
        :return: tuple *(import statements, code)*
        """
        header = 'import numpy\nfrom mlprodict.onnxrt.ops_cpu.op_argmin import _argmin, _argmin_use_numpy_select_last_index'
        # NOTE(review): the leading whitespace inside these literals is
        # copied exactly as it appears in the reviewed text -- confirm
        # it was not altered by whitespace collapsing.
        template = "\n".join([
            "if select_last_index == 0:",
            " return _argmin({0}, axis=axis, keepdims=keepdims)",
            "return _argmin_use_numpy_select_last_index(",
            " {0}, axis=axis, keepdims=keepdims)",
        ])
        return (header, template.format(inputs[0]))
# ``ArgMin`` points to the implementation matching the opset supported
# by the installed onnx package.
if onnx_opset_version() < 12:
    ArgMin = ArgMin_11  # pragma: no cover
else:
    ArgMin = ArgMin_12