Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRunUnaryNum 

9 

10 

class Transpose(OpRunUnaryNum):
    """
    Runtime for the *Transpose* operator.

    The attribute *perm* gives the order of the output axes. An empty
    *perm* (the default declared in ``atts``) means no explicit
    permutation: ``numpy.transpose`` is then called without *axes*, which
    reverses the order of the dimensions.
    """

    atts = {'perm': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Transpose.atts,
                               **options)
        # self.perm is presumably populated by OpRunUnaryNum.__init__ from
        # the node attributes (falling back to atts) -- confirm in ._op.
        # An empty permutation is normalised to None so that _run can
        # select numpy's default (reversed axes) with a single test.
        self.perm_ = None if len(self.perm) == 0 else self.perm

    def _run(self, data):  # pylint: disable=W0221
        """
        Transposes *data*.

        @param      data        array with a ``shape`` attribute
                                (presumably a numpy array)
        @return                 tuple holding the transposed array

        Raises RuntimeError when a permutation is set and its length
        differs from the number of dimensions of *data*.
        """
        if self.perm_ is None:
            return (numpy.transpose(data), )
        if len(self.perm_) != len(data.shape):
            raise RuntimeError(
                "Inconsistent permutation %r with shape %r." % (
                    self.perm_, data.shape))
        return (numpy.transpose(data, axes=self.perm_), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Infers the output shape from the input shape object *x*.

        @param      x           input shape object exposing ``transpose(perm=...)``
        @return                 tuple holding the transposed shape object
        """
        # NOTE(review): unlike _run, this forwards the raw attribute, which
        # may be an empty list; confirm x.transpose handles perm=[] as
        # "reverse the axes".
        return (x.transpose(perm=self.perm), )

    def to_python(self, inputs):
        """
        Returns a python code equivalent to this operator.

        The generated code refers to a name ``perm`` which is expected to be
        bound by the caller that assembles the final function. Both ``None``
        and an empty sequence fall back to ``numpy.transpose`` without
        *axes*, which matches the behaviour of ``_run``.

        @param      inputs      inputs name
        @return                 imports, python code, both as strings
        """
        lines = [
            # An empty perm must not reach numpy.transpose(x, axes=[]):
            # that raises for any array with at least one dimension.
            "if perm is None or len(perm) == 0:",
            "    return numpy.transpose(%s)" % inputs[0],
            "return numpy.transpose(%s, axes=perm)" % inputs[0],
        ]
        return "import numpy", "\n".join(lines)