Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRunReduceNumpy 

9 

10 

class ReduceLogSumExp(OpRunReduceNumpy):
    """
    Runtime operator computing ``log(sum(exp(data)))`` over the axes
    listed in attribute *axes* (every axis when *axes* is empty), keeping
    the reduced dimensions as size 1 when *keepdims* is true.
    """

    # Default attribute values, passed to the base class as
    # ``expected_attributes``.
    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunReduceNumpy.__init__(self, onnx_node, desc=desc,
                                  expected_attributes=ReduceLogSumExp.atts,
                                  **options)

    def _run(self, data):  # pylint: disable=W0221
        """
        Computes the log-sum-exp reduction of *data* in a numerically
        stable way: the largest finite value of each reduced slice is
        subtracted before exponentiating and added back after the log.

        :param data: input array; NOTE(review): presumably floating point --
            integer inputs fail when ``-inf`` is written into the copy below
        :return: tuple holding the single reduced array
        """
        tax = tuple(self.axes) if self.axes else None

        # Exclude infinities from the shift: shifting by +inf would
        # compute inf - inf = nan for the +inf entries themselves.
        data_max = data.copy()
        data_max[numpy.isinf(data_max)] = -numpy.inf
        # initial=-inf lets empty slices reduce (to -inf) instead of raising.
        mx = data_max.max(axis=tax, keepdims=True, initial=-numpy.inf)
        # A slice without any finite value leaves mx at -inf; shifting by it
        # would give -inf - (-inf) = nan. Shift by 0 instead so an all -inf
        # (or empty) slice yields -inf and a slice holding +inf yields +inf.
        mx = numpy.where(numpy.isfinite(mx), mx, numpy.zeros_like(mx))

        # numpy.subtract returns a NumPy scalar for 0-d input, which cannot
        # be used as ``out``; asarray wraps it (and is a no-op for arrays).
        sub = numpy.asarray(numpy.subtract(data, mx))
        exp = numpy.exp(sub, out=sub)
        mxs = numpy.sum(exp, axis=tax, keepdims=True, dtype=data.dtype)
        res = numpy.log(mxs) + mx
        if not self.keepdims:
            res = numpy.squeeze(res, axis=tax)
        return (res, )