Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Functions to help guessing the final graph structure.
4"""
5import numpy
6try:
7 from onnxconverter_common.data_types import Float16TensorType
8except ImportError: # pragma: no cover
9 Float16TensorType = None
10from skl2onnx.common.data_types import (
11 DataType,
12 FloatTensorType, SequenceType, DictionaryType,
13 Int64Type, Int64TensorType, BooleanTensorType,
14 Int32TensorType, DoubleTensorType, FloatType,
15 StringTensorType)
16from skl2onnx.common.data_types import _guess_type_proto
17from skl2onnx.algebra.type_helper import _guess_type as skl2onnx__guess_type
18from skl2onnx.proto import TensorProto
def _guess_type(var):
    "Guesses the skl2onnx type of a variable, unwrapping constant dictionaries first."
    if isinstance(var, dict) and 'value' in var:
        # A registered constant stores its payload under key 'value'.
        var = var['value']  # pragma: no cover
    return skl2onnx__guess_type(var)
def get_defined_inputs(input_names, variables=None, dtype=None):
    """
    Retrieves defined inputs in already declared variables
    based on their names.

    @param input_names input names
    @param variables registered variables created
        by previous operators
    @param dtype float computational type
    @return typed inputs
        as ``tuple(name, type)``
    """
    def guess_type_variable(name):
        "Guesses the type of one input given its name."
        if variables is None:
            # Nothing registered: assume a vector of floats.
            return (  # pragma: no cover
                DoubleTensorType() if dtype == numpy.float64 else FloatTensorType())
        if name in variables:
            ty = variables[name]
            if isinstance(ty, DataType):
                shape = ty.shape
                # shape may be None (unknown); a 0 dimension denotes an
                # empty tensor, which cannot be a valid typed input.
                if shape is not None and 0 in shape:
                    raise RuntimeError(  # pragma: no cover
                        "Shape cannot be empty: name='{}', var={}".format(
                            name, ty))
                return variables[name]
            if isinstance(ty, dict) and 'value' in ty:
                # constant: guess the type from the stored value
                arr = ty['value']
                try:
                    return _guess_type(arr)
                except RuntimeError as e:  # pragma: no cover
                    raise RuntimeError(
                        "Unable to guess type of variable '{}' - {}."
                        "".format(name, arr)) from e
            raise NotImplementedError(  # pragma: no cover
                "Unable to guess type for '{}' form '{}'.".format(
                    name, variables[name]))
        # Unregistered input. Let's assume it is a vector of floats.
        return DoubleTensorType() if dtype == numpy.float64 else FloatTensorType()

    return [(name, guess_type_variable(name)) for name in input_names]
def get_defined_outputs(outputs, onnx_node, typed_inputs=None, variables=None, dtype=None):
    """
    Gets types of predefined outputs when they cannot be inferred.
    Some part of it should be automated based
    on type constraints.

    @param outputs requested outputs
    @param onnx_node :epkg:`ONNX` node definition
    @param typed_inputs known typed inputs of the node
        as ``tuple(name, type)``
    @param variables registered variables created
        by previous operators
    @param dtype float computational type
    @return typed outputs
        as ``tuple(name, type)``
    """
    # Floating-point tensor type matching the computational dtype.
    float_t = DoubleTensorType if dtype == numpy.float64 else FloatTensorType
    op = onnx_node.op_type

    if op == "ZipMap":
        # ZipMap emits a sequence of dictionaries {int64 label: float score}.
        zipmap_type = SequenceType(DictionaryType(Int64Type(), float_t()))
        return [(name, zipmap_type) for name in outputs]
    if op in ("ArgMin", "ArgMax", 'Shape') and len(outputs) == 1:
        # Index and shape operators always produce int64 tensors.
        return [(outputs[0], Int64TensorType())]
    if op in ("Greater", "Less", 'Equal') and len(outputs) == 1:
        # Comparison operators produce boolean tensors.
        return [(outputs[0], BooleanTensorType())]
    if op == "TopK" and len(outputs) == 2:
        # First output keeps the input type, second one holds indices.
        if len(typed_inputs) != 2:
            raise RuntimeError(  # pragma: no cover
                "Wrong typed_inputs, got {}.".format(typed_inputs))
        return [(outputs[0], typed_inputs[0][1]),
                (outputs[1], Int64TensorType())]
    if op == "Cast" and len(outputs) == 1:
        # The destination type comes from the node's first attribute ('to').
        ttyp = _guess_type_proto(onnx_node.attribute[0].i, dims=None)
        return [(outputs[0], ttyp)]
    if op == "ArrayFeatureExtractor":
        # Extracted columns keep the type of the first input.
        if len(typed_inputs) != 2:
            raise RuntimeError(  # pragma: no cover
                "Wrong typed_inputs, got {}.".format(typed_inputs))
        return [(outputs[0], typed_inputs[0][1])]
    if 'Classifier' in op:
        # Good chance that's a classifier: labels then scores.
        return [(outputs[0], Int64TensorType()),
                (outputs[1], float_t())]
    if op in ('Reshape', 'Transpose'):
        # Same element type as the first input, shape left unknown.
        return [(outputs[0], typed_inputs[0][1].__class__())]
    if op == 'Scan':
        # Scan maps every input to an output of the same tensor class.
        if len(outputs) != len(typed_inputs):
            raise RuntimeError(  # pragma: no cover
                "Dimension mismatch, operator Scan should have "
                "the same number of inputs and outputs {} != {}"
                ".".format(len(outputs), len(typed_inputs)))
        return [(name, typed_input[1].__class__())
                for name, typed_input in zip(outputs, typed_inputs)]
    if op == "ConstantOfShape":
        return [(outputs[0], float_t())]
    if len(typed_inputs) == 1 and len(outputs) == 1:
        # Default case: the only output mirrors the only input.
        return [(outputs[0], typed_inputs[0][1])]
    # Last resort: assume every output is a float tensor.
    return [(name, float_t()) for name in outputs]
def proto2vars(values):
    """
    Converts proto values to Variables.

    @param values sequence of :epkg:`ONNX` protobuf objects
        (*ValueInfoProto*, *TypeProto*, ...)
    @return list of ``tuple(name, type)`` with skl2onnx types
    """
    def ptype2vttype(it, shape):
        "Converts a proto tensor element type and a shape into a tensor type."
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatTensorType(shape)
        if it == TensorProto.DOUBLE:  # pylint: disable=E1101
            return DoubleTensorType(shape)
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64TensorType(shape)
        if it == TensorProto.INT32:  # pylint: disable=E1101
            return Int32TensorType(shape)
        if it == TensorProto.BOOL:  # pylint: disable=E1101
            return BooleanTensorType(shape)
        if it == TensorProto.STRING:  # pylint: disable=E1101
            return StringTensorType(shape)
        # Float16TensorType is None when onnxconverter_common is not
        # installed (see the guarded import at the top of the file);
        # only try the FLOAT16 branch when the class is available.
        if Float16TensorType is not None:
            if it == TensorProto.FLOAT16:  # pylint: disable=E1101
                return Float16TensorType(shape)
        raise NotImplementedError(  # pragma: no cover
            "Unrecognized proto type {} with shape {}".format(it, shape))

    def ptype2vtype(it):
        "Converts a proto scalar element type into a scalar type."
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatType()
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64Type()
        raise NotImplementedError(  # pragma: no cover
            "Unrecognized proto type {}".format(it))

    res = []
    for v_ in values:
        v = v_
        name = v.name if hasattr(v, 'name') else None
        if hasattr(v, 'type') and str(v.type) != '':
            # ValueInfoProto: recurse on its embedded TypeProto.
            t = v.type
            v = proto2vars([t])[0][1]
        elif hasattr(v, 'sequence_type') and str(v.sequence_type) != '':
            subtype = proto2vars([v.sequence_type.elem_type])[0][1]
            v = SequenceType(subtype)
        elif hasattr(v, 'tensor_type') and str(v.tensor_type) != '':
            tt = v.tensor_type
            el = tt.elem_type
            shape = tt.shape
            dim = shape.dim
            if len(dim) == 0:
                shape = []
            else:
                # dim_value is 0 for a symbolic (unnamed) dimension.
                shape = [dim[i].dim_value for i in range(len(dim))]
            v = ptype2vttype(el, shape)
        elif hasattr(v, 'map_type') and str(v.map_type) != '':
            mt = v.map_type
            keyt = ptype2vtype(mt.key_type)
            valt = proto2vars([mt.value_type])[0][1]
            v = DictionaryType(keyt, valt)
        else:
            raise RuntimeError(  # pragma: no cover
                "Unable to build a variable from {}.".format(v))
        if v.shape is not None and 0 in v.shape:
            # Replaces 0 by None (unknown dimension).
            new_shape = tuple(None if d == 0 else d for d in v.shape)
            if new_shape in ((None, ), None):
                v = v.__class__()
            else:
                v = v.__class__(new_shape)
        if v.shape is not None and 0 in v.shape:
            raise RuntimeError(  # pragma: no cover
                "Shape cannot be empty: '{}': {}.".format(
                    name, v_))
        res.append((name, v))
    return res