Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class CumSum(OpRun):
    """
    Runtime for the ONNX operator *CumSum*: cumulative sum of input *x*
    along the axis given by the optional second input.

    Attributes (defaults in ``CumSum.atts``):

    * *exclusive*: if non-zero, element *j* of the output only sums the
      elements strictly before *j*, so the first element is zero.
    * *reverse*: if non-zero, the sum is accumulated from the last element
      of the axis towards the first one.

    When *axis* is missing or None, *x* is flattened first (same behaviour
    as :func:`numpy.cumsum` with ``axis=None``).
    """

    atts = {'exclusive': 0, 'reverse': 0}
    python_inputs = ['x', 'axis=None']

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=CumSum.atts,
                       **options)

    @staticmethod
    def _normalize_axis(axis):
        """
        Converts the optional *axis* input into a value usable by numpy.

        :param axis: None, an integer or an array holding one integer
        :return: None or a numpy scalar
        :raises RuntimeError: if *axis* does not hold exactly one value
            or has more than one dimension
        """
        if axis is None:
            return None
        # asarray also accepts plain python integers, which the previous
        # implementation rejected with an AttributeError on `.shape`.
        values = numpy.asarray(axis)
        if values.ndim > 1 or values.size != 1:
            raise RuntimeError(  # pragma no cover
                "axis must be an array of one number not {} "
                "(shape {})".format(axis, values.shape))
        return values.reshape(-1)[0]

    def _cumsum(self, x, axis):
        """
        Computes the cumulative sum honouring attributes *exclusive*
        and *reverse*.

        :param x: input array
        :param axis: None (the input is flattened) or a scalar axis,
            negative values count from the last dimension
        :return: cumulative sum, same shape as *x* (1-D if *axis* is None)
        """
        x = numpy.asarray(x)
        if axis is None:
            # With axis=None the accumulated result is 1-D, so it cannot be
            # written into a multi-dimensional x with out=x: flatten first.
            x = x.reshape(-1)
            axis = 0
        if self.reverse:
            # Accumulate from the end: flip, sum, then flip back below.
            x = numpy.flip(x, axis=axis)
        if self.inplaces.get(0, False) and x.flags['WRITEABLE']:
            # The input buffer may be reused, unless it is read-only.
            res = numpy.cumsum(x, axis=axis, out=x)
        else:
            res = numpy.cumsum(x, axis=axis)
        if self.exclusive:
            # Shift the inclusive sums by one position along axis:
            # y[0] = 0 and y[j] = cumsum[j - 1].
            shifted = numpy.zeros_like(res)
            numpy.moveaxis(shifted, axis, 0)[1:] = (
                numpy.moveaxis(res, axis, 0)[:-1])
            res = shifted
        if self.reverse:
            res = numpy.flip(res, axis=axis)
        return res

    def _run(self, x, *axis):  # pylint: disable=W0221
        """
        Executes the operator.

        :param x: input array
        :param axis: optional second input, the axis to sum over
        :return: tuple with one array, the cumulative sum
        """
        axis = self._normalize_axis(axis[0] if axis else None)
        return (self._cumsum(x, axis), )

    def _infer_shapes(self, x, *axis):  # pylint: disable=W0221
        """Returns the shape information of *x* unchanged."""
        return (x, )

    def _infer_types(self, x, *axis):  # pylint: disable=W0221
        """Returns the type information of *x* unchanged."""
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a dictionary reporting
        no temporary buffer (``temp=0``).
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res

    def to_python(self, inputs):
        """
        Returns python code equivalent to this operator.

        :param inputs: input names (not used, the generated code refers to
            *x* and *axis* as in ``python_inputs`` and to the attributes
            *exclusive* and *reverse*)
        :return: tuple (import statement, code)
        """
        lines = ['if axis is None:',
                 '    x = numpy.ravel(x)',
                 '    axis = 0',
                 'else:',
                 '    axis = numpy.ravel(axis)[0]',
                 'if reverse:',
                 '    x = numpy.flip(x, axis=axis)',
                 'res = numpy.cumsum(x, axis=axis)',
                 'if exclusive:',
                 '    shifted = numpy.zeros_like(res)',
                 '    numpy.moveaxis(shifted, axis, 0)[1:] = '
                 'numpy.moveaxis(res, axis, 0)[:-1]',
                 '    res = shifted',
                 'if reverse:',
                 '    res = numpy.flip(res, axis=axis)',
                 'return res']
        return 'import numpy', "\n".join(lines)