Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7 

8import numpy 

9from numpy.random import RandomState 

10from onnx.defs import onnx_opset_version 

11from ._op import OpRun 

12 

13 

def _dropout(X, drop_probability=0.5, seed=0,
             training_mode=False, return_mask=False):
    """
    Computes the ONNX *Dropout* operator.

    :param X: input tensor (numpy array)
    :param drop_probability: probability of zeroing an element (the ONNX
        *ratio*), a python float or a 0-d numpy array
    :param seed: seed given to :class:`numpy.random.RandomState`
    :param training_mode: dropout is only applied when this is true,
        otherwise *X* is returned unchanged
    :param return_mask: also return the boolean mask of kept elements
    :return: ``(output, )`` or ``(output, mask)`` when *return_mask* is true

    The output keeps the dtype of *X*. Without the final cast, numpy >= 2
    promotion rules (NEP 50) would upcast e.g. a float16 *X* multiplied by
    a float32 0-d *ratio*.
    """
    if drop_probability == 0 or not training_mode:
        # No dropout applied: identity, every element is kept.
        if return_mask:
            return X, numpy.ones(X.shape, dtype=bool)
        return (X, )

    rnd = RandomState(seed)
    # An element is kept when its uniform draw in [0, 1) reaches the ratio;
    # the comparison already produces a boolean array.
    mask = rnd.uniform(0, 1.0, X.shape) >= drop_probability
    # Kept values are rescaled by 1 / (1 - ratio).
    scale = 1. / (1. - drop_probability)
    output = (mask * X * scale).astype(X.dtype, copy=False)
    if return_mask:
        return output, mask
    return (output, )

27 

28 

class DropoutBase(OpRun):
    """
    Shared runtime logic for the versions of the ONNX *Dropout* operator.

    The number of outputs declared by the node decides whether the boolean
    mask is returned in addition to the output.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)
        # 1: output only, 2: output and boolean mask.
        self.nb_outputs = len(onnx_node.output)

    def _private_run(self, X, seed=None, ratio=0.5, training_mode=False):  # pylint: disable=W0221
        """
        Applies dropout to *X*; the mask is returned only when the node
        declares two outputs.
        """
        return _dropout(X, ratio, seed=seed, return_mask=self.nb_outputs == 2,
                        training_mode=training_mode)

    def _infer_shapes(self, *inputs):  # pylint: disable=W0221
        """The output and the mask both have the shape of the first input."""
        X = inputs[0]
        if self.nb_outputs == 1:
            return (X.copy(), )
        if self.nb_outputs == 2:
            return (X.copy(), X.copy())
        raise RuntimeError(  # pragma: no cover
            "Unexpected numbers of output {} > 2.".format(self.nb_outputs))

    def _infer_types(self, *inputs):  # pylint: disable=W0221
        """
        The output has the type of the first input; the mask is boolean
        (``_dropout`` always builds it with a bool dtype).
        """
        X = inputs[0]
        if self.nb_outputs == 1:
            return (X, )
        if self.nb_outputs == 2:
            return (X, numpy.bool_)
        raise RuntimeError(  # pragma: no cover
            "Unexpected numbers of output {} > 2.".format(self.nb_outputs))

    def _infer_sizes(self, *inputs):  # pylint: disable=W0221
        """
        Runs the operator and prepends the size of the temporary buffers:
        one element of the input type plus one boolean per input element.
        """
        res = self.run(*inputs)
        x = inputs[0]
        return (dict(temp=x.size * (
            x.dtype.itemsize + numpy.bool_(True).itemsize)), ) + res

64 

65 

class Dropout_7(DropoutBase):
    """
    Dropout, opset 7: *ratio* is an attribute and there is no
    *training_mode* input, so the runtime never enables training mode
    and the operator returns its input unchanged.
    """

    atts = {'ratio': 0.5}

    def __init__(self, onnx_node, desc=None, **options):
        DropoutBase.__init__(self, onnx_node, desc=desc,
                             expected_attributes=Dropout_7.atts,
                             **options)

    def _run(self, X):  # pylint: disable=W0221
        # ratio must be passed by keyword: the second positional parameter
        # of _private_run is *seed*, not *ratio*.
        return self._private_run(X, ratio=self.ratio)

77 

78 

class Dropout_12(DropoutBase):
    """
    Dropout, opset 12: *ratio* and *training_mode* are optional inputs
    and *seed* is an attribute.
    """

    atts = {'seed': 0}

    def __init__(self, onnx_node, desc=None, **options):
        DropoutBase.__init__(self, onnx_node, desc=desc,
                             expected_attributes=Dropout_12.atts,
                             **options)

    def _run(self, *inputs):  # pylint: disable=W0221
        X = inputs[0]
        # A skipped optional input may show up as a None placeholder
        # (e.g. ratio omitted while training_mode is given); treat it as
        # absent so the defaults apply instead of failing on arithmetic.
        ratio = inputs[1] if len(inputs) > 1 and inputs[1] is not None else 0.5
        training_mode = (
            inputs[2] if len(inputs) > 2 and inputs[2] is not None else False)
        return self._private_run(X, seed=self.seed, ratio=ratio,
                                 training_mode=training_mode)

94 

95 

# Expose the most recent implementation supported by the installed onnx.
if onnx_opset_version() < 12:
    Dropout = Dropout_7  # pragma: no cover
else:
    Dropout = Dropout_12