Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from ._op import OpRun 

8 

9 

class If(OpRun):
    """
    Runtime implementation of operator *If*.

    The node carries two sub-graphs stored as attributes, ``then_branch``
    and ``else_branch``. Each must expose a ``run`` method (a
    ``run_in_scan`` method is preferred when present) together with
    ``input_names`` and ``output_names``. At execution time,
    ``then_branch`` is evaluated when every element of the condition is
    true, ``else_branch`` otherwise.
    """

    atts = {
        'then_branch': None,
        'else_branch': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: ONNX node this operator is built from
        :param desc: optional node description forwarded to ``OpRun``
        :param options: extra options forwarded to ``OpRun``
        :raises RuntimeError: if either branch has no ``run`` method
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=If.atts,
                       **options)
        if not hasattr(self.then_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'then_branch' must have a method 'run', "
                "type {}.".format(type(self.then_branch)))
        if not hasattr(self.else_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'else_branch' must have a method 'run', "
                "type {}.".format(type(self.else_branch)))

        # Resolve once which entry point each branch uses, so _run does
        # not repeat the attribute lookup on every call.
        self._run_meth_then = self._select_run_method(self.then_branch)
        self._run_meth_else = self._select_run_method(self.else_branch)

    @staticmethod
    def _select_run_method(branch):
        """Return ``branch.run_in_scan`` if it exists, ``branch.run`` otherwise."""
        if hasattr(branch, 'run_in_scan'):
            return branch.run_in_scan
        return branch.run

    @staticmethod
    def _check_named_inputs(branch, named_inputs):
        """
        Ensure *named_inputs* provides every input listed in
        ``branch.input_names``.

        :param branch: sub-graph exposing ``input_names``
        :param named_inputs: dictionary of available inputs
        :raises RuntimeError: if an input required by *branch* is missing
        """
        required = branch.input_names
        if len(required) == 0:
            return
        if not named_inputs:
            # Report the names required by *this* branch (the original
            # code always reported then_branch's names here).
            raise RuntimeError(  # pragma: no cover
                "named_inputs is empty but the graph needs {}.".format(
                    required))
        for k in required:
            if k not in named_inputs:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find named input '{}' in\n{}.".format(
                        k, "\n".join(sorted(named_inputs))))

    @staticmethod
    def _is_true(cond):
        """
        Return True when every element of *cond* is true.

        Objects providing an ``all()`` method (e.g. numpy arrays) are
        reduced with it. This also works for 0-d arrays and arrays of any
        rank, which the builtin :func:`all` cannot iterate correctly.
        Other iterables go through the builtin :func:`all`; anything else
        goes through :func:`bool`.
        """
        if hasattr(cond, 'all'):
            return bool(cond.all())
        try:
            iterator = iter(cond)
        except TypeError:
            return bool(cond)
        return all(iterator)

    def _run(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Evaluate one branch depending on *cond*.

        :param cond: condition; ``then_branch`` runs when all its
            elements are true, ``else_branch`` otherwise
        :param named_inputs: dictionary of inputs made available to the
            sub-graphs (defaults to an empty dictionary)
        :return: tuple of the selected branch outputs, ordered as
            ``branch.output_names``
        :raises RuntimeError: if an input required by either branch is
            missing from *named_inputs*
        """
        if named_inputs is None:
            named_inputs = {}
        # Both branches are validated whatever the condition is, so that
        # a missing input is reported even for the branch not taken.
        self._check_named_inputs(self.then_branch, named_inputs)
        self._check_named_inputs(self.else_branch, named_inputs)

        if self._is_true(cond):
            branch, run_meth = self.then_branch, self._run_meth_then
        else:
            branch, run_meth = self.else_branch, self._run_meth_else
        outputs = run_meth(named_inputs)
        return tuple(outputs[name] for name in branch.output_names)

    def _infer_shapes(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Return the output shapes computed from ``then_branch`` only,
        ordered as ``then_branch.output_names``.
        """
        # NOTE(review): assumes both branches produce identically shaped
        # outputs, since else_branch is never inspected here — confirm.
        res = self.then_branch._set_shape_inference_runtime()
        return tuple(res[name] for name in self.then_branch.output_names)

    def _infer_types(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Return the output types computed from ``then_branch`` only,
        ordered as ``then_branch.output_names``.
        """
        # NOTE(review): assumes both branches produce outputs of the same
        # types, since else_branch is never inspected here — confirm.
        res = self.then_branch._set_type_inference_runtime()
        return tuple(res[name] for name in self.then_branch.output_names)