Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Functions to help guessing the final graph structure. 

4""" 

5import numpy 

6try: 

7 from onnxconverter_common.data_types import Float16TensorType 

8except ImportError: # pragma: no cover 

9 Float16TensorType = None 

10from skl2onnx.common.data_types import ( 

11 DataType, 

12 FloatTensorType, SequenceType, DictionaryType, 

13 Int64Type, Int64TensorType, BooleanTensorType, 

14 Int32TensorType, DoubleTensorType, FloatType, 

15 StringTensorType) 

16from skl2onnx.common.data_types import _guess_type_proto 

17from skl2onnx.algebra.type_helper import _guess_type as skl2onnx__guess_type 

18from skl2onnx.proto import TensorProto 

19 

20 

def _guess_type(var):
    """
    Guesses the :epkg:`ONNX` type of one variable,
    unwrapping constants stored as ``{'value': ...}`` dictionaries
    before delegating to :epkg:`skl2onnx`.
    """
    target = var
    if isinstance(target, dict) and 'value' in target:
        # constant wrapped by previous operators
        target = target['value']  # pragma: no cover
    return skl2onnx__guess_type(target)

25 

26 

def get_defined_inputs(input_names, variables=None, dtype=None):
    """
    Retrieves defined inputs in already declared variables
    based on their names.

    @param      input_names     input names
    @param      variables       registered variables created
                                by previous operators
    @param      dtype           float computational type
    @return                     typed inputs
                                as ``tuple(name, type)``
    """
    def guess_type_variable(name):
        # Maps one input name to a skl2onnx type, looking it up in
        # *variables* when possible, assuming a float tensor otherwise.
        if variables is None:
            return (  # pragma: no cover
                DoubleTensorType() if dtype == numpy.float64 else FloatTensorType())
        if name in variables:
            ty = variables[name]
            if isinstance(ty, DataType):
                shape = ty.shape
                # shape may be None (unknown); only a declared 0 dimension
                # is an error because later steps cannot handle empty shapes.
                if shape is not None and 0 in shape:
                    raise RuntimeError(  # pragma: no cover
                        "Shape cannot be empty: name='{}', var={}".format(
                            name, ty))
                return variables[name]
            if isinstance(ty, dict) and 'value' in ty:
                # constant
                arr = ty['value']
                try:
                    return _guess_type(arr)
                except RuntimeError as e:  # pragma: no cover
                    raise RuntimeError(
                        "Unable to guess type of variable '{}' - {}."
                        "".format(name, arr)) from e
            raise NotImplementedError(  # pragma: no cover
                "Unable to guess type for '{}' form '{}'.".format(
                    name, variables[name]))
        # Unknown input. Let's assume it is a vector of floats.
        return DoubleTensorType() if dtype == numpy.float64 else FloatTensorType()

    inputs = [(name, guess_type_variable(name)) for name in input_names]
    return inputs

70 

71 

def get_defined_outputs(outputs, onnx_node, typed_inputs=None, variables=None, dtype=None):
    """
    Gets types of predefined outputs when they cannot be inferred.
    Some part of it should be automated based
    on type constraints.

    @param      outputs         requested outputs
    @param      onnx_node       :epkg:`ONNX` node definition
    @param      typed_inputs    known typed inputs of the node
                                as ``tuple(name, type)``
    @param      variables       registered variables created
                                by previous operators
    @param      dtype           float computational type
    @return                     typed outputs
                                as ``tuple(name, type)``
    """
    # Normalize so the branches below can safely take len(typed_inputs);
    # with the default None, TopK/ArrayFeatureExtractor previously raised
    # TypeError instead of the intended RuntimeError.
    if typed_inputs is None:
        typed_inputs = []
    # Default float tensor type matching the computational dtype.
    ft = DoubleTensorType if dtype == numpy.float64 else FloatTensorType

    # ZipMap
    if onnx_node.op_type == "ZipMap":
        otype = SequenceType(DictionaryType(
            Int64Type(), ft()))
        outputs = [(name, otype) for name in outputs]
    # ArgMin, ArgMax, Shape
    elif onnx_node.op_type in ("ArgMin", "ArgMax", 'Shape') and len(outputs) == 1:
        outputs = [(outputs[0], Int64TensorType())]
    # Greater, Less, Equal
    elif onnx_node.op_type in ("Greater", "Less", 'Equal') and len(outputs) == 1:
        outputs = [(outputs[0], BooleanTensorType())]
    # TopK
    elif onnx_node.op_type == "TopK" and len(outputs) == 2:
        if len(typed_inputs) != 2:
            raise RuntimeError(  # pragma: no cover
                "Wrong typed_inputs, got {}.".format(typed_inputs))
        outputs = [(outputs[0], typed_inputs[0][1]),
                   (outputs[1], Int64TensorType())]
    # Cast
    elif onnx_node.op_type == "Cast" and len(outputs) == 1:
        # The target element type is stored in the node's first attribute.
        ttyp = _guess_type_proto(onnx_node.attribute[0].i, dims=None)
        outputs = [(outputs[0], ttyp)]
    # ArrayFeatureExtractor
    elif onnx_node.op_type == "ArrayFeatureExtractor":
        if len(typed_inputs) != 2:
            raise RuntimeError(  # pragma: no cover
                "Wrong typed_inputs, got {}.".format(typed_inputs))
        outputs = [(outputs[0], typed_inputs[0][1])]
    elif 'Classifier' in onnx_node.op_type:
        # Good chance that's a classifier.
        outputs = [(outputs[0], Int64TensorType()),
                   (outputs[1], ft())]
    # Reshape
    elif onnx_node.op_type in ('Reshape', 'Transpose'):
        # Same element type as the first input, unknown shape.
        outputs = [(outputs[0], typed_inputs[0][1].__class__())]
    # Scan
    elif onnx_node.op_type == 'Scan':
        if len(outputs) != len(typed_inputs):
            raise RuntimeError(  # pragma: no cover
                "Dimension mismatch, operator Scan should have "
                "the same number of inputs and outputs {} != {}"
                ".".format(len(outputs), len(typed_inputs)))
        outputs = [(o, t[1].__class__())
                   for o, t in zip(outputs, typed_inputs)]
    # ConstantOfShape
    elif onnx_node.op_type == "ConstantOfShape":
        outputs = [(outputs[0], ft())]

    # Default case
    # Assuming the only output is the same as the only input.
    elif len(typed_inputs) == 1 and len(outputs) == 1:
        outputs = [(outputs[0], typed_inputs[0][1])]
    # Default
    else:
        outputs = [(name, ft()) for name in outputs]
    return outputs

146 

147 

def proto2vars(values):
    """
    Converts proto values to Variables.

    @param      values      iterable of :epkg:`ONNX` *ValueInfoProto*
                            (or type protos)
    @return                 list of ``tuple(name, type)``
    """
    def ptype2vttype(it, shape):
        # Maps a TensorProto element type to a skl2onnx tensor type.
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatTensorType(shape)
        if it == TensorProto.DOUBLE:  # pylint: disable=E1101
            return DoubleTensorType(shape)
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64TensorType(shape)
        if it == TensorProto.INT32:  # pylint: disable=E1101
            return Int32TensorType(shape)
        if it == TensorProto.BOOL:  # pylint: disable=E1101
            return BooleanTensorType(shape)
        if it == TensorProto.STRING:  # pylint: disable=E1101
            return StringTensorType(shape)
        # Float16TensorType is None when onnxconverter_common could not be
        # imported; the original condition was inverted (`is None`) and
        # either skipped FLOAT16 or called None(shape).
        if Float16TensorType is not None:
            if it == TensorProto.FLOAT16:  # pylint: disable=E1101
                return Float16TensorType(shape)
        raise NotImplementedError(  # pragma: no cover
            "Unrecognized proto type {} with shape {}".format(it, shape))

    def ptype2vtype(it):
        # Maps a TensorProto element type to a scalar skl2onnx type
        # (used for map keys).
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatType()
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64Type()
        raise NotImplementedError(  # pragma: no cover
            "Unrecognized proto type {}".format(it))

    res = []
    for v_ in values:
        v = v_
        name = v.name if hasattr(v, 'name') else None
        if hasattr(v, 'type') and str(v.type) != '':
            # ValueInfoProto: recurse on its type proto.
            t = v.type
            v = proto2vars([t])[0][1]
        elif hasattr(v, 'sequence_type') and str(v.sequence_type) != '':
            subtype = proto2vars([v.sequence_type.elem_type])[0][1]
            v = SequenceType(subtype)
        elif hasattr(v, 'tensor_type') and str(v.tensor_type) != '':
            tt = v.tensor_type
            el = tt.elem_type
            shape = tt.shape
            dim = shape.dim
            if len(dim) == 0:
                shape = []
            else:
                shape = [dim[i].dim_value for i in range(len(dim))]
            v = ptype2vttype(el, shape)
        elif hasattr(v, 'map_type') and str(v.map_type) != '':
            mt = v.map_type
            keyt = ptype2vtype(mt.key_type)
            valt = proto2vars([mt.value_type])[0][1]
            v = DictionaryType(keyt, valt)
        else:
            raise RuntimeError(  # pragma: no cover
                "Unable to build a variable from {}.".format(v))
        if v.shape is not None and 0 in v.shape:
            # Replaces 0 by None
            new_shape = tuple(None if d == 0 else d for d in v.shape)
            if new_shape in ((None, ), None):
                v = v.__class__()
            else:
                v = v.__class__(new_shape)
        if v.shape is not None and 0 in v.shape:
            raise RuntimeError(  # pragma: no cover
                "Shape cannot be empty: '{}': {}.".format(
                    name, v_))
        res.append((name, v))
    return res