Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ..shape_object import ShapeObject 

10 

11 

def gather_numpy_2(self, dim, index):
    """
    Fallback gather used by :func:`gather_numpy` for 2-D inputs.

    Computes ``out[i][j] = self[index[i][j]][j]`` if ``dim == 0`` and
    ``out[i][j] = self[i][index[i][j]]`` if ``dim == 1``. The
    computation is delegated to :func:`numpy.take_along_axis`, so it
    also works for inputs of any rank.

    :param self: input array
    :param dim: axis along which to gather (negative values count from
        the last axis)
    :param index: integer array with the same rank as *self*
    :return: array of shape ``index.shape`` with the dtype of *self*
    """
    # The previous loop read only ``index[i][0]`` and ignored *dim*. It
    # failed to reshape when *index* had several columns and returned
    # dim=1 results when dim == 0.
    return numpy.take_along_axis(self, index, axis=dim)

19 

20 

def gather_numpy(self, dim, index):
    """
    Gathers values along an axis specified by dim.
    For a 3-D tensor the output is specified by:

    ::

        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    :param self: input tensor (numpy array)
    :param dim: The axis along which to index, negative values count
        from the last axis
    :param index: A tensor of integer indices of elements to gather,
        same rank as *self*; negative indices count from the end of
        axis *dim*
    :return: tensor of gathered values, its shape is ``index.shape``
        and its dtype the dtype of *self*
    :raises ValueError: if *index* and *self* differ on any dimension
        other than *dim*
    :raises IndexError: if an index is out of bounds along *dim*

    The gather is done by :func:`numpy.take_along_axis`. Unlike
    :func:`numpy.choose`, it has no limit on the length of axis *dim*
    and accepts negative indices.

    See `How to do scatter and gather operations in numpy?
    <https://stackoverflow.com/questions/46065873/
    how-to-do-scatter-and-gather-operations-in-numpy/46204790#46204790>`_
    """
    if dim < 0:
        # The slicing below needs a non-negative axis: with dim == -1,
        # shape[:dim] + shape[dim + 1:] would be shape[:-1] + shape.
        dim += len(self.shape)
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError(  # pragma: no cover
            "Except for dimension {}, all dimensions of "
            "index and self should be the same size".format(dim))
    return numpy.take_along_axis(self, index, axis=dim)

56 

57 

class GatherElements(OpRun):
    """
    Runtime for ONNX operator *GatherElements*.

    Gathers values of the first input (*data*) along attribute *axis*
    at the positions given by the second input (*indices*). The output
    has the shape of *indices* and the dtype of *data*.
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=GatherElements.atts,
                       **options)

    def _run(self, data, indices):  # pylint: disable=W0221
        if indices.size == 0:
            # The output shape equals the indices shape, which may still
            # have non-zero dimensions, e.g. (2, 0); do not flatten
            # it to (0,).
            return (numpy.empty(indices.shape, dtype=data.dtype), )
        y = gather_numpy(data, self.axis, indices)
        return (y, )

    def _infer_shapes(self, data, indices):  # pylint: disable=W0221
        # Only the dtype is propagated; the shape is left unknown (None).
        return (ShapeObject(None, data.dtype), )

    def _infer_types(self, data, indices):  # pylint: disable=W0221
        # The output has the same type as *data*.
        return (data, )

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        # Runs the operator and reports the total byte size of the
        # inputs as temporary memory, followed by the run results.
        res = self.run(*args)
        return (dict(temp=sum(a.size * a.dtype.itemsize for a in args)), ) + res

    def to_python(self, inputs):
        """
        Returns python code equivalent to this operator.

        :param inputs: names of the two inputs (data, indices)
        :return: tuple (imports, code)

        The generated code reads the attribute through the name
        ``axis``. That name is not defined in the returned code and is
        presumably bound by the code exporter — verify. It uses
        :func:`numpy.take_along_axis`, like the runtime path. That
        function accepts negative indices and has no limit on the
        length of *axis*, unlike ``numpy.choose``.
        """
        lines = ['return numpy.take_along_axis(%s, %s, axis=axis)' % (
            inputs[0], inputs[1])]
        return "import numpy", "\n".join(lines)