Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRun 

10from ..shape_object import ShapeObject 

11 

12 

def reshape_reference_implementation(data, shape, allowzero=0):
    """
    Reference implementation of the ONNX *Reshape* operator.

    :param data: array to reshape
    :param shape: requested shape, any integer array-like (numpy array,
        list or tuple); ``-1`` lets :func:`numpy.reshape` infer one
        dimension, ``0`` copies the dimension of *data* at the same index
        unless *allowzero* is set
    :param allowzero: when non-zero, every ``0`` in *shape* is kept as an
        explicit zero-sized dimension instead of being copied from *data*
        (the default ``0`` keeps the historical behaviour)
    :return: the reshaped array, as returned by :func:`numpy.reshape`
    :raises RuntimeError: if a ``0`` in *shape* sits at an index beyond
        the rank of *data*
    """
    if allowzero:
        # Zeros are literal dimensions, which is how numpy.reshape
        # already interprets them.
        return numpy.reshape(data, shape)

    # Work on a copy so the caller's shape is never modified. The copy is
    # a numpy array, so the comparison below is element-wise even when
    # *shape* is a plain list or tuple (``[0, 3] == 0`` is just False).
    new_shape = numpy.copy(shape)
    zeros_index = numpy.where(new_shape == 0)
    if len(data.shape) == 1 and data.shape[0] == 0:
        # Empty 1-D input: keep zeros in *shape* literal, copying them
        # from a rank-1 shape would fail for any index greater than 0.
        return numpy.reshape(data, shape)
    try:
        new_shape[zeros_index] = numpy.array(data.shape)[zeros_index]
    except IndexError as e:
        raise RuntimeError(
            "Unable to reshape from shape %r to shape %r (or %r)."
            "" % (data.shape, shape, new_shape)) from e
    return numpy.reshape(data, new_shape)

27 

28 

class CommonReshape(OpRun):
    """
    Runtime shared by every *Reshape* version: the computation itself is
    delegated to :func:`reshape_reference_implementation`.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes, **options)

    def _run(self, data, shape):  # pylint: disable=W0221
        reshaped = reshape_reference_implementation(data, shape)
        return (reshaped, )

    def _infer_shapes(self, data, shape):  # pylint: disable=W0221
        # No static shape is given (None); only the dtype of *data* is
        # propagated to the output.
        result = ShapeObject(None, dtype=data.dtype)
        return (result, )

    def _infer_types(self, data, shape):  # pylint: disable=W0221
        # The output type is the type of the input *data*.
        return (data, )

    def _infer_sizes(self, *args, **kwargs):
        # Runs the operator, then prefixes its outputs with a size record
        # reporting no temporary buffer.
        outputs = self.run(*args, **kwargs)
        extra = ({'temp': 0}, )
        return extra + outputs

48 

49 

class Reshape_5(CommonReshape):
    """
    *Reshape* runtime without any attribute (opset 5; ``Reshape_13``
    reuses it unchanged).
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None, **options):
        # *expected_attributes* is accepted but not forwarded: the parent
        # constructor receives its default value (None).
        super().__init__(onnx_node, desc=desc, **options)

54 

55 

class Reshape_13(Reshape_5):
    """*Reshape* for opset 13, behaving exactly like :class:`Reshape_5`."""

58 

59 

class Reshape_14(CommonReshape):
    """
    *Reshape* from opset 14, which adds the ``allowzero`` attribute.

    With ``allowzero=0`` (default) a ``0`` in the target shape copies the
    matching input dimension; with a non-zero ``allowzero`` the ``0`` is
    kept as a literal zero-sized dimension.
    """

    atts = {'allowzero': 0}

    def __init__(self, onnx_node, desc=None, **options):
        CommonReshape.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Reshape_14.atts, **options)

    def _run(self, data, shape):  # pylint: disable=W0221
        # NOTE(review): assumes OpRun stores the node attribute (or the
        # default from ``atts``) as ``self.allowzero``; getattr falls back
        # to the previous behaviour if it does not.
        if getattr(self, 'allowzero', 0):
            # Zeros in *shape* are explicit dimensions, exactly how
            # numpy.reshape interprets them.
            return (numpy.reshape(data, shape), )
        return CommonReshape._run(self, data, shape)

68 

69 

# Default *Reshape* runtime: the opset-14 version when the installed onnx
# package supports opset 14 or later, otherwise the opset-5 version.
Reshape = Reshape_14 if onnx_opset_version() >= 14 else Reshape_5