Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRun 

10from ._op_onnx_numpy import ( # pylint: disable=E0611,E0401 

11 topk_element_min_double, topk_element_max_double, topk_element_fetch_double, 

12 topk_element_min_float, topk_element_max_float, topk_element_fetch_float, 

13 topk_element_min_int64, topk_element_max_int64, topk_element_fetch_int64) 

14 

15 

def topk_sorted_implementation(X, k, axis, largest):
    """
    Retrieves the top-k elements.

    @param      X           data
    @param      k           k in top-k, an integer or an array
                            holding a single integer
    @param      axis        axis chosen to select the top-k elements
    @param      largest     largest (1) or smallest (0)
    @return                 top-k values, top-k indices

    See function `_kneighbors_reduce_func
    <https://github.com/scikit-learn/scikit-learn/tree/master/
    sklearn/neighbors/base.py#L304>`_.

    Values are only negated for floating types: negating unsigned
    integers wraps around and negating the minimum signed integer
    overflows, both of which would silently corrupt the ordering.
    Outside the 2D fast path, equal values are returned by increasing
    index (a stable sort is used).
    """
    if isinstance(k, numpy.ndarray):
        if k.size != 1:
            raise RuntimeError(  # pragma: no cover
                "k must be an integer not %r." % k)
        # ravel() also handles a 0-d array where k[0] would fail
        k = k.ravel()[0]
    k = int(k)

    if (len(X.shape) == 2 and axis == 1 and
            (not largest or numpy.issubdtype(X.dtype, numpy.floating))):
        # Fast path: partial partition, then sort only the k kept columns.
        sample_range = numpy.arange(X.shape[0])[:, None]
        if largest == 0:
            sorted_indices = numpy.argpartition(X, axis=axis, kth=k - 1)
            sorted_indices = sorted_indices[:, :k]
            # argpartition doesn't guarantee sorted order, so we sort again
            sorted_indices = sorted_indices[
                sample_range, numpy.argsort(X[sample_range, sorted_indices])]
        else:
            # negation is exact for floating types only (see docstring)
            sorted_indices = numpy.argpartition(-X, axis=axis, kth=k - 1)
            sorted_indices = sorted_indices[:, :k]
            # argpartition doesn't guarantee sorted order, so we sort again
            sorted_indices = sorted_indices[
                sample_range, numpy.argsort(-X[sample_range, sorted_indices])]
        sorted_distances = X[sample_range, sorted_indices]
        return sorted_distances, sorted_indices

    if largest:
        # Reverse along the axis, sort stably, reverse the permutation and
        # map reversed positions back to original ones: this yields a
        # descending order in which equal values keep increasing indices,
        # without negating the data.
        n = X.shape[axis]
        reversed_order = numpy.argsort(
            numpy.flip(X, axis=axis), axis=axis, kind='stable')
        sorted_indices = (n - 1) - numpy.flip(reversed_order, axis=axis)
    else:
        sorted_indices = numpy.argsort(X, axis=axis, kind='stable')
    topk_sorted_indices = numpy.take(sorted_indices, numpy.arange(k), axis=axis)
    topk_sorted_values = numpy.take_along_axis(X, topk_sorted_indices, axis=axis)
    return topk_sorted_values, topk_sorted_indices

61 

62 

def topk_sorted_implementation_cpp(X, k, axis, largest, th_para=50):
    """
    Retrieves the top-k elements using a C++
    implementation when the axis is the last dimension,
    otherwise, it falls back to
    @see fn topk_sorted_implementation.

    @param      X           data
    @param      k           k in top-k, an integer or an array
                            holding a single integer
    @param      axis        axis chosen to select the top-k elements,
                            negative values count from the last dimension
    @param      largest     largest (1) or smallest (0)
    @param      th_para     threshold for parallelisation
    @return                 top-k values, top-k indices

    The C++ kernels are used for float64, float32 and int64 data only.
    When *k* is 0, empty arrays are returned whose shape is the shape
    of *X* with the selected axis reduced to 0.
    """
    if isinstance(k, numpy.ndarray):
        if k.size != 1:
            raise RuntimeError(  # pragma: no cover
                "k must be an integer not %r." % k)
        k = k.ravel()[0]
    k = int(k)
    if axis < 0:
        axis += len(X.shape)

    if k == 0:
        # Always return a (values, indices) pair, callers unpack it.
        shape = X.shape[:axis] + (0,) + X.shape[axis + 1:]
        return (numpy.empty(shape, dtype=X.dtype),
                numpy.empty(shape, dtype=numpy.int64))

    if axis != len(X.shape) - 1:
        return topk_sorted_implementation(X, k, axis, largest)

    if X.dtype == numpy.float64:
        select = topk_element_max_double if largest else topk_element_min_double
        fetch = topk_element_fetch_double
    elif X.dtype == numpy.float32:
        select = topk_element_max_float if largest else topk_element_min_float
        fetch = topk_element_fetch_float
    elif X.dtype == numpy.int64:
        select = topk_element_max_int64 if largest else topk_element_min_int64
        fetch = topk_element_fetch_int64
    else:
        # no C++ kernel for this type
        return topk_sorted_implementation(X, k, axis, largest)

    topk_sorted_indices = select(X, k, True, th_para)
    topk_sorted_values = fetch(X, topk_sorted_indices)
    return topk_sorted_values, topk_sorted_indices

114 

115 

class _CommonTopK(OpRun):
    """
    Code shared by every version of operator *TopK*.
    It also stores ``th_para``, a hidden threshold handed to the
    C++ implementation above which the computation is parallelised.
    """

    atts = {'axis': -1}

    def __init__(self, *args, **options):
        OpRun.__init__(self, *args, **options)
        # forwarded as th_para to topk_sorted_implementation_cpp
        self.th_para = 50

    def _common_run(self, data, ink, largest=1):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*, shared by all versions.
        It delegates to @see fn topk_sorted_implementation_cpp
        and casts the returned indices to int64.

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        k = ink[0]
        axis = self.axis
        if axis < 0:
            # negative axes count from the last dimension
            axis += len(data.shape)
        values, indices = topk_sorted_implementation_cpp(
            data, k, axis, largest, self.th_para)
        return values, indices.astype(numpy.int64)

    def _infer_shapes(self, data, ink):  # pylint: disable=W0221
        axis = self.axis
        if axis < 0:
            axis += len(data)
        values_shape = data.copy()
        # the selected dimension becomes a symbolic name unique to this node
        values_shape[axis] = "ntopk%s" % hex(id(self))[2:]
        indices_shape = values_shape.copy(dtype=numpy.int64)
        return values_shape, indices_shape

    def _infer_types(self, x, ink):  # pylint: disable=E0202,W0221
        # values keep the input type, indices are int64
        return x, numpy.int64

157 

158 

class TopK_1(_CommonTopK):
    """
    Operator *TopK* for opset 1, where *k* is a node attribute.
    """

    atts = {'axis': -1, 'k': None}

    def __init__(self, onnx_node, desc=None, **options):
        # This version expects its own attributes, including 'k'
        # (TopK_10.atts only declares 'axis').
        _CommonTopK.__init__(self, onnx_node, desc=desc,
                             expected_attributes=TopK_1.atts,
                             **options)

    def _run(self, data):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*.
        *k* comes from the node attribute and is forwarded to the
        shared implementation as a one-element list.

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        return _CommonTopK._common_run(self, data, [self.k])

    def _infer_shapes(self, data):  # pylint: disable=W0221
        return _CommonTopK._infer_shapes(self, data, [self.k])

    def _infer_types(self, data):  # pylint: disable=W0221
        # TopK has two outputs (values, indices), like the other versions.
        return _CommonTopK._infer_types(self, data, [self.k])

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        res = self.run(*args)
        x = args[0]
        return (dict(temp=x.dtype.itemsize * self.k * 2), ) + res

193 

194 

class TopK_10(_CommonTopK):
    """
    Operator *TopK* for opset 10, where *k* is given as an input tensor.
    """

    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        _CommonTopK.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=TopK_10.atts, **options)

    def _run(self, data, ink):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*.
        Both inputs are forwarded to the shared implementation,
        which selects the largest values.

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        return _CommonTopK._common_run(self, data, ink)

    def _infer_sizes(self, data, ink):  # pylint: disable=W0221
        outputs = self.run(data, ink)
        temp_bytes = data.dtype.itemsize * ink[0] * 2
        return (dict(temp=temp_bytes), ) + outputs

222 

223 

class TopK_11(_CommonTopK):
    """
    Operator *TopK* for opset 11, adding attributes *largest* and
    *sorted*; only ``sorted=1`` is supported.
    """

    atts = {'axis': -1, 'largest': 1, 'sorted': 1}

    def __init__(self, onnx_node, desc=None, **options):
        _CommonTopK.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=TopK_11.atts, **options)
        accepted_sorted = (True, 1)
        if self.sorted not in accepted_sorted:
            raise RuntimeError(  # pragma: no cover
                "TopK does not implement anything for sorted=0.")

    def _run(self, data, ink):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*.
        The attribute *largest* decides whether the largest or
        the smallest values are selected.

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        return _CommonTopK._common_run(
            self, data, ink, self.largest)

    def _infer_sizes(self, data, ink):  # pylint: disable=W0221
        outputs = self.run(data, ink)
        temp_bytes = data.dtype.itemsize * ink[0] * 2
        return (dict(temp=temp_bytes), ) + outputs

254 

255 

# TopK is the most recent version supported by the installed onnx package.
if onnx_opset_version() < 10:  # pragma: no cover
    TopK = TopK_1
elif onnx_opset_version() < 11:  # pragma: no cover
    TopK = TopK_10
else:
    TopK = TopK_11