Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# coding: utf-8
2"""
3@file
4@brief Wraps runtime into a :epkg:`scikit-learn` transformer.
5"""
6from io import BytesIO
7import numpy
8import pandas
9import onnx
10from sklearn.base import BaseEstimator, TransformerMixin
11from skl2onnx.algebra.onnx_operator_mixin import OnnxOperatorMixin
12from skl2onnx.helpers.onnx_helper import (
13 load_onnx_model, enumerate_model_node_outputs)
14from skl2onnx.helpers.onnx_helper import select_model_inputs_outputs
15from skl2onnx.common.data_types import (
16 FloatTensorType, DoubleTensorType,
17 Int64TensorType)
18from ..onnx_tools.onnx2py_helper import _var_as_dict, onnx_model_opsets
19from ..onnx_tools.exports.skl2onnx_helper import add_onnx_graph
20from ..onnxrt import OnnxInference
class OnnxTransformer(BaseEstimator, TransformerMixin, OnnxOperatorMixin):
    """
    Calls :epkg:`onnxruntime` or the runtime implemented
    in this package to transform input based on a ONNX graph.
    It follows :epkg:`scikit-learn` API
    so that it can be included in a :epkg:`scikit-learn` pipeline.
    See notebook :ref:`transferlearningrst` for an example.

    :param onnx_bytes: bytes, or an object exposing ``SerializeToString``
        (such as an :epkg:`ONNX` model) which is serialized immediately
    :param output_name: string
        requested output name or None to request all and
        have method *transform* to store all of them in a dataframe
    :param enforce_float32: boolean
        :epkg:`onnxruntime` only supports *float32*,
        :epkg:`scikit-learn` usually uses double floats, this parameter
        ensures that every array of double floats is converted into
        single floats
    :param runtime: string, defines the runtime to use
        as described in @see cl OnnxInference.
    :param change_batch_size: some models are converted for
        a specific batch size, this parameter changes it,
        None to avoid changing it, 0 to fix an undefined
        first dimension
    :param reshape: reshape the output to get
        a matrix and not a multidimensional array
    """

    def __init__(self, onnx_bytes, output_name=None, enforce_float32=True,
                 runtime='python', change_batch_size=None, reshape=False):
        BaseEstimator.__init__(self)
        TransformerMixin.__init__(self)
        # Models are always stored serialized so that the estimator
        # only holds plain bytes.
        self.onnx_bytes = (onnx_bytes
                           if not hasattr(onnx_bytes, 'SerializeToString')
                           else onnx_bytes.SerializeToString())
        self.output_name = output_name
        self.enforce_float32 = enforce_float32
        self.runtime = runtime
        self.change_batch_size = change_batch_size
        self.reshape = reshape

    def __repr__(self):  # pylint: disable=W0222
        """
        usual, the serialized model is truncated to its first
        and last 10 bytes when it is longer than 20 bytes
        """
        ob = self.onnx_bytes
        if len(ob) > 20:
            ob = ob[:10] + b"..." + ob[-10:]
        return ("{0}(onnx_bytes={1}, output_name={2}, enforce_float32={3}, "
                "runtime='{4}')".format(
                    self.__class__.__name__, ob, self.output_name,
                    self.enforce_float32, self.runtime))

    def fit(self, X=None, y=None, **fit_params):
        """
        Loads the :epkg:`ONNX` model.

        If *output_name* is not one of the graph outputs, the graph is
        truncated so that this intermediate result becomes its output.
        If *change_batch_size* is not None, the first dimension of the
        inputs is changed. The (possibly modified) model is then loaded
        with the selected *runtime* into attribute ``onnxrt_``.

        :param X: unused
        :param y: unused
        :param fit_params: additional parameter (unused)
        :return: self
        """
        from ..onnx_tools.optim.onnx_helper import change_input_first_dimension
        onx = onnx.load(BytesIO(self.onnx_bytes))
        self.op_version = onnx_model_opsets(onx)

        output_names = set(
            o.name for o in onx.graph.output)  # pylint: disable=E1101
        updated = False
        if (self.output_name is not None and
                self.output_name not in output_names):
            # The model refers to intermediate outputs.
            onx = select_model_inputs_outputs(
                onx, outputs=[self.output_name])
            updated = True

        if self.change_batch_size is not None:
            onx = change_input_first_dimension(
                onx, self.change_batch_size)
            updated = True

        # Only serialize again when the graph was modified.
        onnx_bytes = (
            onx.SerializeToString() if updated else self.onnx_bytes)
        self.onnxrt_ = OnnxInference(onnx_bytes, runtime=self.runtime)
        self.inputs_ = self.onnxrt_.input_names
        self.inputs_shape_types_ = self.onnxrt_.input_names_shapes_types
        return self

    def _check_arrays(self, inputs):
        """
        Ensures that double floats are converted into single floats
        if *enforce_float32* is True. Once the model is fitted, it also
        checks the shape and dtype of every :epkg:`numpy` array against
        the declared model inputs and raises an exception on mismatch.
        Dictionary *inputs* is modified in place.

        :param inputs: dictionary ``{input name: value}``
        :raises RuntimeError: more inputs than expected or unexpected shape
        :raises TypeError: unexpected dtype
        """
        has = hasattr(self, "onnxrt_")
        sht = self.inputs_shape_types_ if has else None
        if sht is not None and len(sht) < len(inputs):
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of inputs {} > {} (expected).".format(
                    len(inputs), len(sht)))
        for i, k in enumerate(inputs):
            v = inputs[k]
            if isinstance(v, numpy.ndarray):
                if v.dtype == numpy.float64 and self.enforce_float32:
                    # Converted arrays are not checked any further.
                    inputs[k] = v.astype(numpy.float32)
                    continue
                if not has:
                    continue
                # NOTE(review): the i-th entry of the dictionary is compared
                # with the i-th declared model input, this assumes both
                # follow the same order.
                exp = sht[i]
                # The first dimension (batch) is not checked.
                if exp[1] != ('?', ) and exp[1][1:] != v.shape[1:]:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected shape for input '{}': {} != {} "
                        "(expected).".format(
                            k, v.shape, exp[1]))
                if ((v.dtype == numpy.float32 and exp[2] != 'tensor(float)') or
                        (v.dtype == numpy.float64 and exp[2] != 'tensor(double)')):
                    raise TypeError(  # pragma: no cover
                        "Unexpected dtype for input '{}': {} != {} "
                        "(expected).".format(
                            k, v.dtype, exp[2]))

    def transform(self, X, y=None, **inputs):
        """
        Runs the predictions. If *X* is a dataframe,
        the function assumes every column is a separate input,
        if *X* is a dictionary, it maps input names to values,
        otherwise, *X* is considered as a first input and *inputs*
        can be used to specify extra inputs.

        :param X: iterable, data to process
            (or first input if several expected)
        :param y: unused
        :param inputs: :epkg:`ONNX` graph support multiple inputs,
            each column of a dataframe is converted into as many inputs if
            *X* is a dataframe, otherwise, *X* is considered as the first
            input and *inputs* can be used to specify the other ones,
            *inputs* overwrites any input with the same name found in *X*
        :return: the single requested output (an array, or a
            :epkg:`DataFrame` if it is a list) or a :epkg:`DataFrame`
            concatenating all outputs
        :raises AttributeError: the model was not fitted
        """
        if not hasattr(self, "onnxrt_"):
            raise AttributeError(  # pragma: no cover
                "Transform OnnxTransformer must be fit first.")
        rt_inputs = {}
        if isinstance(X, numpy.ndarray):
            rt_inputs[self.inputs_[0]] = X
        elif isinstance(X, pandas.DataFrame):
            for c in X.columns:
                rt_inputs[c] = X[c]
        elif isinstance(X, dict):
            # Keyword inputs, added below, take precedence over the
            # entries of the dictionary.
            rt_inputs.update(X)
        elif isinstance(X, list):
            if len(self.inputs_) == 1:
                rt_inputs[self.inputs_[0]] = numpy.array(X)
            else:
                # Every row holds one value per input.
                for i, name in enumerate(self.inputs_):
                    rt_inputs[name] = [row[i] for row in X]

        rt_inputs.update(inputs)

        names = ([self.output_name]
                 if self.output_name else self.onnxrt_.output_names)
        self._check_arrays(rt_inputs)
        doutputs = self.onnxrt_.run(rt_inputs)
        outputs = [doutputs[name] for name in names]

        if self.reshape:
            first = outputs[0]
            n_rows = (first.shape[0] if isinstance(first, numpy.ndarray)
                      else len(first))
            # Only arrays can be reshaped, lists (of dictionaries for
            # example) are kept as they are.
            outputs = [o.reshape((n_rows, -1))
                       if isinstance(o, numpy.ndarray) else o
                       for o in outputs]

        if self.output_name or len(outputs) == 1:
            if isinstance(outputs[0], list):
                return pandas.DataFrame(outputs[0])
            return outputs[0]

        # Several outputs: every output becomes one or more columns
        # of a single dataframe (here names == self.onnxrt_.output_names).
        concat = []
        colnames = []
        for k, v in zip(names, outputs):
            if isinstance(v, numpy.ndarray):
                if len(v.shape) == 1:
                    v = v.reshape((-1, 1))
                    colnames.append(k)
                elif len(v.shape) == 2:
                    colnames.extend("%s%d" % (k, i) for i in range(v.shape[1]))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected shape for results %r: %r." % (k, v.shape))
            if isinstance(v, list):
                if len(v) == 0:
                    raise RuntimeError(  # pragma: no cover
                        "Output %r is empty." % k)
                if not isinstance(v[0], dict):
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected type for output %r - value=%r."
                        "" % (k, v[0]))
                # A list of dictionaries becomes one column per key,
                # keys sorted to get a deterministic column order.
                df = pandas.DataFrame(v)
                cols = list(sorted(df.columns))
                v = df[cols].copy().values
                colnames.extend("%s%d" % (k, i) for i in range(v.shape[1]))
            concat.append(v)
        res = numpy.hstack(concat)
        return pandas.DataFrame(res, columns=colnames)

    def fit_transform(self, X, y=None, **inputs):
        """
        Loads the *ONNX* model and runs the predictions.

        :param X: iterable, data to process
            (or first input if several expected)
        :param y: unused
        :param inputs: :epkg:`ONNX` graph support multiple inputs,
            each column of a dataframe is converted into as many inputs if
            *X* is a dataframe, otherwise, *X* is considered as the first
            input and *inputs* can be used to specify the other ones,
            they are forwarded to both *fit* and *transform*
        :return: see method *transform*
        """
        return self.fit(X, y=y, **inputs).transform(X, y, **inputs)

    @staticmethod
    def enumerate_create(onnx_bytes, output_names=None, enforce_float32=True):
        """
        Creates multiple *OnnxTransformer*,
        one for each requested intermediate node.

        :param onnx_bytes: bytes, or a model accepted by
            *load_onnx_model*
        :param output_names: list of strings, requested output names
            or None to request every intermediate output
        :param enforce_float32: boolean
            :epkg:`onnxruntime` only supports *float32*,
            :epkg:`scikit-learn` usually uses double floats, this parameter
            ensures that every array of double floats is converted into
            single floats
        :return: iterator on OnnxTransformer *('output name', OnnxTransformer)*
        """
        selected = None if output_names is None else set(output_names)
        model = load_onnx_model(onnx_bytes)
        for out in enumerate_model_node_outputs(model):
            # Skip unrequested outputs before building a truncated model.
            if selected is not None and out not in selected:
                continue
            m = select_model_inputs_outputs(model, out)
            tr = OnnxTransformer(m.SerializeToString(),
                                 enforce_float32=enforce_float32)
            yield out, tr

    def onnx_parser(self):
        """
        Returns a parser for this model.
        The parser returns the output names of the fitted runtime.
        """
        def parser(scope=None, inputs=None):
            if scope is None:
                raise RuntimeError(
                    "scope cannot be None (parser of class %r)."
                    "" % type(self))
            if inputs is None:
                raise RuntimeError(
                    "inputs cannot be None (parser of class %r)."
                    "" % type(self))
            if (not hasattr(self, 'onnxrt_') or
                    not hasattr(self.onnxrt_, 'output_names')):
                raise RuntimeError(  # pragma: no cover
                    'OnnxTransformer not fit.')
            if len(inputs) != len(self.inputs_):
                raise RuntimeError(  # pragma: no cover
                    "Mismatch between the number of inputs, expected %r, "
                    "got %r." % (self.inputs_, inputs))
            return self.onnxrt_.output_names
        return parser

    def onnx_shape_calculator(self):
        """
        Returns a shape calculator for this model.
        It sets the type of every output of the operator from the
        outputs declared in the loaded :epkg:`ONNX` graph.
        Only tensors of float, double or int64 are supported.
        """
        def shape_calculator(operator):
            cout = self.onnxrt_.output_names
            if len(operator.outputs) != len(cout):
                raise RuntimeError(  # pragma: no cover
                    "Mismatched number of outputs: {} != {}."
                    "".format(len(operator.outputs), len(cout)))
            for out_op, out in zip(operator.outputs, self.onnxrt_.obj.graph.output):
                var = _var_as_dict(out)
                if var['type']['kind'] != 'tensor':
                    raise NotImplementedError(  # pragma: no cover
                        "Not yet implemented for output:\n{}".format(out))
                shape = var['type']['shape']
                # A first dimension equal to 0 is replaced by None
                # (undefined), scalar shapes are left unchanged.
                if shape and shape[0] == 0:
                    shape = (None,) + tuple(shape[1:])
                elem = var['type']['elem']
                if elem == 'float':
                    out_op.type = FloatTensorType(shape=shape)
                elif elem == 'int64':
                    out_op.type = Int64TensorType(shape=shape)
                elif elem == 'double':
                    out_op.type = DoubleTensorType(shape=shape)
                else:
                    raise NotImplementedError(  # pragma: no cover
                        "Not yet implemented for elem_type:\n{}".format(elem))
        return shape_calculator

    def onnx_converter(self):
        """
        Returns a converter for this model.
        The converter inserts the loaded :epkg:`ONNX` graph
        (or *onnx_model* if specified) into the container.
        """
        def converter(scope, operator, container, onnx_model=None):
            op = operator.raw_operator
            onx = onnx_model or op.onnxrt_.obj
            add_onnx_graph(scope, operator, container, onx)

        return converter

    @property
    def opsets(self):
        """
        Returns the opsets as dictionary ``{domain: opset}``.
        The fitted model is used if available, otherwise
        the serialized model is loaded.
        """
        if hasattr(self, 'onnxrt_'):
            model = self.onnxrt_.obj
        else:
            model = load_onnx_model(self.onnx_bytes)
        res = {}
        for oimp in model.opset_import:
            res[oimp.domain] = oimp.version
        return res