Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ..shape_object import ShapeObject 

10from ._op import OpRunUnaryNum, OpRun 

11 

12 

class Squeeze_1(OpRunUnaryNum):
    """
    Attribute-based *Squeeze*: used directly before opset 11 and through
    :class:`Squeeze_11` up to opset 12. The axes to remove come from the
    node attribute ``axes``; when no axes are given every dimension of
    size 1 is removed.
    """

    # 'keepdims' is declared here but not read anywhere in this class.
    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Squeeze_1.atts,
                               **options)
        # Normalise self.axes to either None (squeeze every dimension of
        # size 1) or a tuple of Python ints. Every sequence form (list,
        # tuple, 1-D or 0-d ndarray) is treated the same way, so an empty
        # ndarray behaves like an empty list instead of squeezing nothing.
        if isinstance(self.axes, (numpy.ndarray, list, tuple)):
            axes = tuple(int(a) for a in numpy.ravel(self.axes))
            self.axes = axes if axes else None

    def _run(self, data):  # pylint: disable=W0221
        """
        Removes the dimensions listed in ``self.axes`` (all dimensions of
        size 1 when ``self.axes`` is None).

        :param data: input array
        :return: 1-tuple holding the squeezed array
        """
        axes = self.axes
        if isinstance(axes, list):
            # self.axes may have been reassigned after __init__
            axes = tuple(axes)
        # A single numpy.squeeze call with the whole tuple is used rather
        # than squeezing axis by axis: removing one axis renumbers the
        # following ones, which broke unsorted or mixed-sign axes.
        return (numpy.squeeze(data, axis=axes), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """Returns the shape of *x* squeezed along ``self.axes``."""
        return (x.squeeze(axis=self.axes), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """The output type is the input type."""
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a size record reporting ``temp=0``
        (presumably: no temporary buffer needed — confirm against OpRun).
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res

46 

47 

class Squeeze_11(Squeeze_1):
    """
    *Squeeze* for opsets 11 and 12: same attribute-based behaviour as
    :class:`Squeeze_1`, which it reuses unchanged.
    """

50 

51 

class Squeeze_13(OpRun):
    """
    *Squeeze* from opset 13 on: *axes* is an optional second input
    given at runtime rather than a node attribute.
    """

    atts = {'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Squeeze_13.atts,
                       **options)
        # axes are supplied as an input to _run, never stored on the node
        self.axes = None

    def _run(self, data, axes=None):  # pylint: disable=W0221
        """
        Squeezes *data* along *axes*.

        :param data: input array
        :param axes: None (remove every dimension of size 1), an iterable
            of axes (converted to a tuple) or a single axis
        :return: 1-tuple holding the squeezed array
        """
        if axes is None:
            return (numpy.squeeze(data), )
        which = tuple(axes) if hasattr(axes, '__iter__') else axes
        return (numpy.squeeze(data, axis=which), )

    def _infer_shapes(self, x, axes=None):  # pylint: disable=W0221
        """
        The output shape is left undetermined (None) because the axes are
        only known when the operator runs; the dtype of *x* is kept.
        """
        return (ShapeObject(None, dtype=x.dtype), )

    def _infer_types(self, x, axes=None):  # pylint: disable=W0221
        """The output type is the input type."""
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a size record reporting ``temp=0``
        (presumably: no temporary buffer needed — confirm against OpRun).
        """
        return (dict(temp=0), ) + self.run(*args, **kwargs)

81 

82 

# Pick the Squeeze implementation that matches the opset of the installed
# onnx package: input-based axes from 13, attribute-based before that.
Squeeze = (Squeeze_13 if onnx_opset_version() >= 13
           else Squeeze_11 if onnx_opset_version() >= 11
           else Squeeze_1)