Hide keyboard shortcuts

Keyboard shortcuts on this page

r m x p   toggle line displays

j k   next/previous highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from collections import OrderedDict 

8import numpy 

9from onnx.defs import onnx_opset_version 

10from ._op import OpRunUnaryNum 

11 

12 

class Clip_6(OpRunUnaryNum):
    """
    Runtime for ONNX operator *Clip*, version 6: the bounds are the
    attributes *min* and *max*. Their defaults are the lowest and
    highest finite float32 values.
    """

    atts = {'min': -3.4028234663852886e+38,
            'max': 3.4028234663852886e+38}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Clip_6.atts,
                               **options)

    def _run(self, data):  # pylint: disable=W0221
        """
        Clips *data* between attributes ``self.min`` and ``self.max``.

        :param data: tensor to clip
        :return: 1-tuple holding the clipped tensor, with the dtype of *data*
        """
        # Delegate to the in-place variant when input 0 may be overwritten.
        if self.inplaces.get(0, False):
            return self._run_inplace(data)
        clipped = numpy.clip(data, self.min, self.max)
        # numpy may promote the dtype (e.g. float bounds); restore the input's.
        if clipped.dtype != data.dtype:
            clipped = clipped.astype(data.dtype)
        return (clipped, )

    def _run_inplace(self, data):
        """
        Clips *data* in place, writing the result into *data* itself.

        :param data: tensor to clip, overwritten
        :return: 1-tuple holding the array returned by ``numpy.clip``
        """
        clipped = numpy.clip(data, self.min, self.max, out=data)
        return (clipped, )

    def to_python(self, inputs):
        """
        Returns the Python code equivalent to this operator.

        :param inputs: input names; only ``inputs[0]`` is used
        :return: tuple (imports, code)
        """
        # NOTE(review): ``min_`` / ``max_`` are presumably how the exporter
        # names the *min* / *max* attributes in generated code — confirm.
        code = "return numpy.clip(%s, min_, max_)" % inputs[0]
        return ("import numpy", code)

35 

36 

class Clip_11(OpRunUnaryNum):
    """
    Runtime for ONNX operator *Clip*, version 11: the bounds are the
    optional inputs *min* and *max* rather than attributes. Their defaults
    are the lowest and highest finite float32 values, so a missing bound
    does not clip on that side.
    """

    version_higher_than = 11
    mandatory_inputs = ['X']
    optional_inputs = OrderedDict([
        ('min', -3.4028234663852886e+38),
        ('max', 3.4028234663852886e+38)
    ])

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               **options)

    def run(self, x, *minmax):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run``.

        :param x: tensor to clip
        :param minmax: optional lower bound then optional upper bound
        :return: 1-tuple holding the clipped tensor
        :raises TypeError: if ``_run`` raises a TypeError; it is re-raised
            with the types of every input
        """
        try:
            res = self._run(x, *minmax)
        except TypeError as e:
            raise TypeError("Issues with types {} (operator {}).".format(
                ", ".join(str(type(_)) for _ in (x, ) + tuple(minmax)),
                self.__class__.__name__)) from e
        return res

    @staticmethod
    def _bounds(minmax):
        """
        Splits the optional extra inputs into ``(amin, amax)``.

        :param minmax: zero, one or two values (lower, upper)
        :return: tuple ``(amin, amax)``, None for a bound that is not given
        """
        amin = minmax[0] if len(minmax) > 0 else None
        amax = minmax[1] if len(minmax) > 1 else None
        return amin, amax

    def _run(self, data, *minmax):  # pylint: disable=W0221
        """
        Clips *data* between the optional bounds.

        :param data: tensor to clip
        :param minmax: optional lower bound then optional upper bound
        :return: 1-tuple holding a new tensor with the dtype of *data*
        """
        # Delegate to the in-place variant when input 0 may be overwritten.
        if self.inplaces.get(0, False):
            return self._run_inplace(data, *minmax)
        amin, amax = self._bounds(minmax)
        if amin is None and amax is None:
            # No bound at all: clipping is the identity. Calling
            # numpy.clip(data, None, None) raises ValueError on some NumPy
            # versions, so return a copy directly.
            return (data.copy(), )
        res = numpy.clip(data, amin, amax)
        # numpy may promote the dtype (e.g. bounds of a wider type);
        # restore the input's dtype.
        return (res, ) if res.dtype == data.dtype else (res.astype(data.dtype), )

    def _run_inplace(self, data, *minmax):  # pylint: disable=W0221
        """
        Clips *data* in place between the optional bounds.

        :param data: tensor to clip, overwritten
        :param minmax: optional lower bound then optional upper bound
        :return: 1-tuple holding the clipped tensor
        """
        amin, amax = self._bounds(minmax)
        if amin is None and amax is None:
            # Identity; avoids numpy.clip(data, None, None) (see ``_run``).
            return (data, )
        res = numpy.clip(data, amin, amax, out=data)
        return (res, )

    def infer_shapes(self, x, *minmax):  # pylint: disable=E0202,W0221
        """
        Infers the output shape from *x* only; the bounds are ignored.

        :raises TypeError: re-raised with the dtype of *x*
        """
        try:
            return self._infer_shapes(x)
        except TypeError as e:
            raise TypeError("Issues with types {} (operator {}).".format(
                x.dtype, self.__class__.__name__)) from e

    def infer_types(self, x, *minmax):  # pylint: disable=E0202,W0221
        """
        Infers the output type from *x* only; the bounds are ignored.

        :raises TypeError: re-raised with the dtype of *x*
        """
        try:
            return self._infer_types(x)
        except TypeError as e:
            raise TypeError("Issues with types {} (operator {}).".format(
                x.dtype, self.__class__.__name__)) from e

    def to_python(self, inputs):
        """
        Returns the Python code equivalent to this operator.

        :param inputs: input names; only ``inputs[0]`` is used
        :return: tuple (imports, code)
        """
        # NOTE(review): since opset 11 the bounds are inputs, not attributes,
        # so nothing in this class binds ``min_`` / ``max_``. Confirm the
        # exporter defines them; otherwise the code should use
        # ``inputs[1]`` / ``inputs[2]`` when present.
        return ("import numpy",
                "return numpy.clip(%s, min_, max_)" % inputs[0])

95 

96 

# Expose the implementation matching the highest opset known to the
# installed onnx package: bounds as inputs (>= 11) or as attributes.
Clip = Clip_11 if onnx_opset_version() >= 11 else Clip_6