Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6 

7.. versionadded:: 0.7 

8""" 

9import numpy 

10from ._op import OpRun 

11from ..shape_object import ShapeObject 

12 

13 

class Loop(OpRun):
    """
    Runtime implementation of the ONNX operator *Loop*.

    The loop body is a subgraph stored in attribute *body*. Every
    iteration feeds the body with the iteration number, the current
    condition and the loop-carried values, then copies the body
    outputs ``1..`` back into the loop-carried inputs ``2..`` for the
    next iteration.
    """

    atts = {
        'body': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Loop.atts,
                       **options)
        if not hasattr(self.body, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'body' must have a method 'run', "
                "type {}.".format(type(self.body)))

        # Use the body's ``run_in_scan`` method when it has one,
        # otherwise its plain ``run`` method.
        self._run_meth = (self.body.run_in_scan
                          if hasattr(self.body, 'run_in_scan')
                          else self.body.run)
        # Names the body reads from the enclosing graph without
        # declaring them as regular inputs.
        self.additional_inputs = self.body.static_inputs

    def need_context(self):
        """
        The operator Loop needs to know all results produced
        so far as the loop may silently access one of them.
        Some of them are not listed among the node inputs
        (a kind of static variable). Returns True when the body
        declares such static inputs.
        """
        return len(self.additional_inputs) > 0

    def _loop_initial_inputs(self, v_initial, args, context):
        """
        Builds the dictionary of inputs given to the body before the
        first iteration.

        :param v_initial: value of the first loop-carried input
            (third input of the body)
        :param args: values of the remaining loop-carried inputs,
            mapped onto the last ``len(args)`` input names of the body
        :param context: results available in the enclosing graph,
            used to fill *additional_inputs*
        :return: dictionary ``{input name: value}``; the iteration
            number and the condition are left to None
        :raises RuntimeError: if *additional_inputs* is not empty and
            *context* is None or misses one of them
        """
        loop_inputs = self.body.input_names
        inputs = {name: None for name in loop_inputs}
        inputs[loop_inputs[2]] = v_initial
        if len(args) > 0:
            # The extra loop-carried values map onto the trailing names.
            begin = len(loop_inputs) - len(args)
            for name, val in zip(loop_inputs[begin:], args):
                inputs[name] = val
        if len(self.additional_inputs) > 0:
            if context is None:
                raise RuntimeError(
                    "Additional inputs %r are missing and context is None."
                    "" % (self.additional_inputs, ))
            for a in self.additional_inputs:
                if a not in context:
                    raise RuntimeError(
                        "Additional inputs %r not found in context\n%s." % (
                            a, "\n".join(sorted(map(str, context)))))
                inputs[a] = context[a]
        return inputs

    def _run(self, M, cond, v_initial, *args, callback=None, context=None):  # pylint: disable=W0221
        """
        Executes the loop.

        :param M: maximum trip count, or None for no limit
            (the input is optional in ONNX)
        :param cond: initial condition, or None to start as True
            (the input is optional in ONNX)
        :param v_initial: initial value of the first loop-carried input
        :param args: initial values of the other loop-carried inputs
        :param callback: optional function called after every iteration
            as ``callback(inputs, context=context)``
        :param context: results computed so far by the enclosing graph
        :return: tuple of the body outputs ``1..`` after the last
            iteration, or the initial loop-carried values if the body
            never ran
        :raises RuntimeError: if an additional input cannot be found or
            the body returns a None condition
        :raises TypeError: if one of the returned values is None
        """
        input_names = self.body.input_names
        output_names = self.body.output_names
        cond_name = output_names[0]
        inputs = self._loop_initial_inputs(v_initial, args, context)

        # A missing condition means "keep going"; a missing trip count
        # means "no limit" (checked in the while condition below).
        if cond is None:
            cond = numpy.array(True)
        # The iteration counter takes the trip count's dtype when it has
        # one, int64 otherwise.
        iter_dtype = getattr(M, 'dtype', numpy.int64)

        it = 0
        while cond and (M is None or it < M):
            inputs[input_names[0]] = numpy.array(it, dtype=iter_dtype)
            inputs[input_names[1]] = cond
            outputs = self._run_meth(inputs)
            cond = outputs[cond_name]
            if cond is None:
                raise RuntimeError(
                    "condition %r returned by the subgraph cannot be None."
                    "" % cond_name)
            # Body outputs 1.. feed the loop-carried inputs 2.. of the
            # next iteration; zip stops at the shorter list.
            for i, o in zip(input_names[2:], output_names[1:]):
                inputs[i] = outputs[o]
            if callback is not None:
                callback(inputs, context=context)
            it += 1

        if it == 0:
            # The body never ran: the condition output keeps the initial
            # condition and loop-carried outputs keep their initial values.
            outputs = {cond_name: cond}
            for i, o in zip(input_names[2:], output_names[1:]):
                outputs[o] = inputs[i]
            for o in output_names:
                if o not in outputs:
                    # NOTE(review): placeholder for outputs with no initial
                    # value; numpy.empty(()) is a 0-d float64 array whose
                    # value is unspecified.
                    outputs[o] = numpy.empty(shape=tuple())

        # NOTE(review): ONNX defines the outputs after the loop-carried
        # ones as scan outputs accumulated over iterations; here every
        # output is the value produced by the last iteration.
        res = tuple(outputs[name] for name in output_names[1:])
        if any(r is None for r in res):
            raise TypeError(  # pragma: no cover
                "Operator Loop produces a None value.")
        return res

    def _infer_shapes(self, M, cond, v_initial, *args):  # pylint: disable=W0221
        """
        Infers the shapes of the outputs ``1..`` of the body.

        A shape computed by the body's shape-inference runtime is used
        when available. Otherwise a :class:`ShapeObject` is built from
        the shape and dtype listed in ``output_names_shapes_types``.
        """
        res = self.body._set_shape_inference_runtime()
        # Entries are read as (name, shape, dtype, ...).
        outputs = {k[0]: k[1:] for k in self.body.output_names_shapes_types}
        ret = []
        for name in self.body.output_names[1:]:
            if name in res:
                ret.append(res[name])
            else:
                find = outputs[name]
                ret.append(ShapeObject(find[0], dtype=find[1]))
        return tuple(ret)

    def _infer_types(self, M, cond, v_initial, *args):  # pylint: disable=W0221
        """
        Infers the types of the outputs ``1..`` of the body using the
        body's type-inference runtime.
        """
        res = self.body._set_type_inference_runtime()
        return tuple(res[name] for name in self.body.output_names[1:])

    def _infer_sizes(self, M, cond, v_initial, *args, context=None):  # pylint: disable=W0221
        """
        Runs the loop and sums the sizes reported by the body.

        After every iteration, ``self.body.infer_sizes`` is called on
        the current inputs. It is expected to return a dictionary of
        dictionaries whose inner values are summed into ``temp``.

        :return: ``(dict(temp=total),) + outputs of _run``
        """
        store = []

        def callback_(inputs, context=None):
            res = self.body.infer_sizes(inputs, context=context)
            store.append(res)

        res = self._run(M, cond, v_initial, *args, callback=callback_,
                        context=context)

        temp = 0
        for v in store:
            for vv in v.values():
                temp += sum(vv.values())
        return (dict(temp=temp), ) + res