# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import numpy
from onnx.defs import onnx_opset_version
from ._op import OpRun



def _batchnorm_test_mode(x, s, bias, mean, var, epsilon=1e-5):
    # Reshape the per-channel statistics to (C, 1, ..., 1) so that they
    # broadcast over an input laid out as (N, C, d1, ..., dn).
    dims_x = len(x.shape)
    dim_ones = (1,) * (dims_x - 2)
    s = s.reshape(-1, *dim_ones)
    bias = bias.reshape(-1, *dim_ones)
    mean = mean.reshape(-1, *dim_ones)
    var = var.reshape(-1, *dim_ones)
    y = s * (x - mean) / numpy.sqrt(var + epsilon) + bias
    return y.astype(x.dtype)
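

# A minimal usage sketch of the helper above. The function name and the
# shapes are hypothetical, added for illustration only: feeding the tensor's
# own channel statistics yields an output with (approximately) zero mean and
# unit variance per channel.
def _example_batchnorm_test_mode():
    x = numpy.arange(2 * 3 * 2 * 2, dtype=numpy.float32).reshape((2, 3, 2, 2))
    s = numpy.ones(3, dtype=numpy.float32)      # per-channel scale
    bias = numpy.zeros(3, dtype=numpy.float32)  # per-channel bias
    mean = x.mean(axis=(0, 2, 3))               # statistics over N, H, W
    var = x.var(axis=(0, 2, 3))
    return _batchnorm_test_mode(x, s, bias, mean, var)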



def _batchnorm_training_mode(x, s, bias, mean, var, momentum=0.9,
                             epsilon=1e-5):
    # Batch statistics are computed over every axis except the channel
    # axis (axis 1); the running statistics are updated as an exponential
    # moving average driven by *momentum*.
    axis = tuple(numpy.delete(numpy.arange(len(x.shape)), 1))
    saved_mean = x.mean(axis=axis)
    saved_var = x.var(axis=axis)
    output_mean = mean * momentum + saved_mean * (1 - momentum)
    output_var = var * momentum + saved_var * (1 - momentum)
    y = _batchnorm_test_mode(x, s, bias, saved_mean, saved_var,
                             epsilon=epsilon)
    return (y.astype(x.dtype), saved_mean.astype(x.dtype),
            saved_var.astype(x.dtype), output_mean.astype(x.dtype),
            output_var.astype(x.dtype))
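

# A minimal sketch of the training-mode helper; the function name and inputs
# are hypothetical, for illustration only. The batch statistics normalize x,
# while the returned running statistics follow an exponential moving average,
# e.g. new_mean = 0.9 * mean + 0.1 * saved_mean for the default momentum.
def _example_batchnorm_training_mode():
    x = numpy.random.randn(4, 3, 8, 8).astype(numpy.float32)
    s = numpy.ones(3, dtype=numpy.float32)
    bias = numpy.zeros(3, dtype=numpy.float32)
    running_mean = numpy.zeros(3, dtype=numpy.float32)
    running_var = numpy.ones(3, dtype=numpy.float32)
    return _batchnorm_training_mode(
        x, s, bias, running_mean, running_var, momentum=0.9)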



class BatchNormalization_9(OpRun):
    # Runtime for BatchNormalization before opset 14: only the inference
    # behaviour (first output) is implemented.

    atts = {'epsilon': 1e-5, 'momentum': 0.9}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=BatchNormalization_9.atts,
                       **options)

    def _run(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        res = _batchnorm_test_mode(
            x, scale, bias, mean, var, epsilon=self.epsilon)
        return (res, )

    def _infer_shapes(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        return (x, )

    def _infer_types(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        return (x, )

    def _infer_sizes(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        res = self.run(x, scale, bias, mean, var)
        return (dict(temp=x.size * x.dtype.itemsize * 2), ) + res
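

# A hedged sketch of the kind of node this class evaluates. This is plain
# onnx.helper usage, not an mlprodict API, and the function name and
# input/output names are hypothetical: opset-9 BatchNormalization takes five
# inputs and, in this runtime, produces the single output Y.
def _example_batchnorm9_node():
    from onnx import helper
    return helper.make_node(
        'BatchNormalization',
        inputs=['x', 'scale', 'bias', 'mean', 'var'],
        outputs=['y'],
        epsilon=1e-5)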



class BatchNormalization_14(OpRun):

    atts = {'epsilon': 1e-5, 'momentum': 0.9, 'training_mode': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=BatchNormalization_14.atts,
                       **options)

    def _run(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        if self.training_mode == 0:
            # Inference: normalize with the stored statistics.
            res = _batchnorm_test_mode(
                x, scale, bias, mean, var, epsilon=self.epsilon)
            return (res, )
        # Training: normalize with the batch statistics and update the
        # running mean and variance.
        res, saved_mean, saved_var, output_mean, output_var = (
            _batchnorm_training_mode(x, scale, bias, mean, var,
                                     self.momentum, self.epsilon))
        return res, saved_mean, saved_var, output_mean, output_var

    def _infer_shapes(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        if self.training_mode == 0:
            return (x, )
        return (x, scale, bias, mean, var)

    def _infer_types(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        if self.training_mode == 0:
            return (x, )
        return (x, scale, bias, mean, var)

    def _infer_sizes(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        res = self.run(x, scale, bias, mean, var)
        if self.training_mode == 0:
            return (dict(temp=x.size * x.dtype.itemsize * 2), ) + res
        return (dict(temp=x.size * x.dtype.itemsize * 4), ) + res
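

# A hedged sketch of an opset-14 node with training enabled, again plain
# onnx.helper usage with hypothetical names: with training_mode=1 the
# operator also exposes the updated running statistics as outputs.
def _example_batchnorm14_node():
    from onnx import helper
    return helper.make_node(
        'BatchNormalization',
        inputs=['x', 'scale', 'bias', 'mean', 'var'],
        outputs=['y', 'running_mean', 'running_var'],
        epsilon=1e-5, momentum=0.9, training_mode=1)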



# Export the implementation matching the installed onnx package.
if onnx_opset_version() >= 14:
    BatchNormalization = BatchNormalization_14
else:  # pragma: no cover
    BatchNormalization = BatchNormalization_9