# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.
:githublink:`%|py|7`
"""
import numpy
from onnx.defs import onnx_opset_version
from ._op import OpRun
from ._op_onnx_numpy import ( # pylint: disable=E0611,E0401
topk_element_min_double, topk_element_max_double, topk_element_fetch_double,
topk_element_min_float, topk_element_max_float, topk_element_fetch_float,
topk_element_min_int64, topk_element_max_int64, topk_element_fetch_int64)
[docs]def topk_sorted_implementation(X, k, axis, largest):
"""
Retrieves the top-k elements.
:param X: data
:param k: k in top-k
:param axis: axis chosen to select the top-k elements
:param largest: largest (1) or smallest (0)
:return: top-k values, top-k indices
See function `_kneighbors_reduce_func
<https://github.com/scikit-learn/scikit-learn/tree/master/
sklearn/neighbors/base.py#L304>`_.
:githublink:`%|py|29`
"""
if len(X.shape) == 2 and axis == 1:
sample_range = numpy.arange(X.shape[0])[:, None]
if largest == 0:
sorted_indices = numpy.argpartition(X, axis=axis, kth=k - 1)
sorted_indices = sorted_indices[:, :k]
# argpartition doesn't guarantee sorted order, so we sort again
sorted_indices = sorted_indices[
sample_range, numpy.argsort(X[sample_range, sorted_indices])]
else:
sorted_indices = numpy.argpartition(-X, axis=axis, kth=k - 1)
sorted_indices = sorted_indices[:, :k]
# argpartition doesn't guarantee sorted order, so we sort again
sorted_indices = sorted_indices[
sample_range, numpy.argsort(-X[sample_range, sorted_indices])]
sorted_distances = X[sample_range, sorted_indices]
return sorted_distances, sorted_indices
sorted_indices = numpy.argsort(X, axis=axis)
sorted_values = numpy.sort(X, axis=axis)
if largest:
sorted_indices = numpy.flip(sorted_indices, axis=axis)
sorted_values = numpy.flip(sorted_values, axis=axis)
ark = numpy.arange(k)
topk_sorted_indices = numpy.take(sorted_indices, ark, axis=axis)
topk_sorted_values = numpy.take(sorted_values, ark, axis=axis)
return topk_sorted_values, topk_sorted_indices
[docs]def topk_sorted_implementation_cpp(X, k, axis, largest, th_para=50):
"""
Retrieves the top-k elements using a C++
implementation when the axis is the last dimension,
otherwise, it falls back to
:func:`topk_sorted_implementation <mlprodict.onnxrt.ops_cpu.op_topk.topk_sorted_implementation>`.
:param X: data
:param k: k in top-k
:param axis: axis chosen to select the top-k elements
:param largest: largest (1) or smallest (0)
:param th_para: threshold for parallelisation
:return: top-k values, top-k indices
:githublink:`%|py|71`
"""
if axis != len(X.shape) - 1:
return topk_sorted_implementation(X, k, axis, largest)
if X.dtype == numpy.float64:
if largest:
topk_sorted_indices = topk_element_max_double(X, k, True, th_para)
else:
topk_sorted_indices = topk_element_min_double(X, k, True, th_para)
topk_sorted_values = topk_element_fetch_double(X, topk_sorted_indices)
elif X.dtype == numpy.float32:
if largest:
topk_sorted_indices = topk_element_max_float(X, k, True, th_para)
else:
topk_sorted_indices = topk_element_min_float(X, k, True, th_para)
topk_sorted_values = topk_element_fetch_float(X, topk_sorted_indices)
elif X.dtype == numpy.int64:
if largest:
topk_sorted_indices = topk_element_max_int64(X, k, True, th_para)
else:
topk_sorted_indices = topk_element_min_int64(X, k, True, th_para)
topk_sorted_values = topk_element_fetch_int64(X, topk_sorted_indices)
else:
return topk_sorted_implementation(X, k, axis, largest)
return topk_sorted_values, topk_sorted_indices
[docs]class _CommonTopK(OpRun):
"""
Ths class hides a parameter used as a threshold above
which the parallelisation is started: ``th_para``.
:githublink:`%|py|101`
"""
atts = {'axis': -1}
[docs] def __init__(self, *args, **options):
OpRun.__init__(self, *args, **options)
self.th_para = 50
[docs] def _common_run(self, data, ink, largest=1): # pylint: disable=W0221
"""
Runtime for operator *TopK*.
The implementation is not the most efficient
as it sorts everything then extracts the top *k*
values.
.. warning::
ONNX specifications may be imprecise in case of negative value
for axis. The implementation follows what :epkg:`onnxruntime`
does in `top_k.cc
<https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
:githublink:`%|py|121`
"""
k = ink[0]
axis = self.axis if self.axis >= 0 else (self.axis + len(data.shape))
sort, sorti = topk_sorted_implementation_cpp(
data, k, axis, largest, self.th_para)
return (sort, sorti.astype(numpy.int64))
[docs] def _infer_shapes(self, data, ink): # pylint: disable=W0221
axis = self.axis if self.axis >= 0 else (self.axis + len(data))
sh = data.copy()
pref = str(hex(id(self))[2:])
sh[axis] = "ntopk%s" % pref
shi = sh.copy(dtype=numpy.int64)
return (sh, shi)
[docs]class TopK_1(_CommonTopK):
atts = {'axis': -1, 'k': None}
[docs] def __init__(self, onnx_node, desc=None, **options):
_CommonTopK.__init__(self, onnx_node, desc=desc,
expected_attributes=TopK_10.atts,
**options)
[docs] def _run(self, data): # pylint: disable=W0221
"""
Runtime for operator *TopK*.
The implementation is not the most efficient
as it sorts everything then extracts the top *k*
values.
.. warning::
ONNX specifications may be imprecise in case of negative value
for axis. The implementation follows what :epkg:`onnxruntime`
does in `top_k.cc
<https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
:githublink:`%|py|158`
"""
return _CommonTopK._common_run(self, data, [self.k])
[docs] def _infer_shapes(self, data): # pylint: disable=W0221
return _CommonTopK._infer_shapes(self, data, [self.k])
[docs]class TopK_10(_CommonTopK):
atts = {'axis': -1}
[docs] def __init__(self, onnx_node, desc=None, **options):
_CommonTopK.__init__(self, onnx_node, desc=desc,
expected_attributes=TopK_10.atts,
**options)
[docs] def _run(self, data, ink): # pylint: disable=W0221
"""
Runtime for operator *TopK*.
The implementation is not the most efficient
as it sorts everything then extracts the top *k*
values.
.. warning::
ONNX specifications may be imprecise in case of negative value
for axis. The implementation follows what :epkg:`onnxruntime`
does in `top_k.cc
<https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
:githublink:`%|py|186`
"""
return _CommonTopK._common_run(self, data, ink)
[docs]class TopK_11(_CommonTopK):
atts = {'axis': -1, 'largest': 1, 'sorted': 1}
[docs] def __init__(self, onnx_node, desc=None, **options):
_CommonTopK.__init__(self, onnx_node, desc=desc,
expected_attributes=TopK_11.atts,
**options)
if self.sorted not in (True, 1):
raise RuntimeError( # pragma: no cover
"TopK does not implement anything for sorted=0.")
[docs] def _run(self, data, ink): # pylint: disable=W0221
"""
Runtime for operator *TopK*.
The implementation is not the most efficient
as it sorts everything then extracts the top *k*
values.
.. warning::
ONNX specifications may be imprecise in case of negative value
for axis. The implementation follows what :epkg:`onnxruntime`
does in `top_k.cc
<https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
:githublink:`%|py|214`
"""
return _CommonTopK._common_run(self, data, ink, self.largest)
if onnx_opset_version() >= 11:
TopK = TopK_11
elif onnx_opset_version() >= 10: # pragma: no cover
TopK = TopK_10
else: # pragma: no cover
TopK = TopK_1