Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# coding: utf-8 

2""" 

3@file 

4@brief Wraps runtime into a :epkg:`scikit-learn` transformer. 

5""" 

6from io import BytesIO 

7import numpy 

8import pandas 

9import onnx 

10from sklearn.base import BaseEstimator, TransformerMixin 

11from skl2onnx.algebra.onnx_operator_mixin import OnnxOperatorMixin 

12from skl2onnx.helpers.onnx_helper import ( 

13 load_onnx_model, enumerate_model_node_outputs) 

14from skl2onnx.helpers.onnx_helper import select_model_inputs_outputs 

15from skl2onnx.common.data_types import ( 

16 FloatTensorType, DoubleTensorType, 

17 Int64TensorType) 

18from ..onnx_tools.onnx2py_helper import _var_as_dict, onnx_model_opsets 

19from ..onnx_tools.exports.skl2onnx_helper import add_onnx_graph 

20from ..onnxrt import OnnxInference 

21 

22 

class OnnxTransformer(BaseEstimator, TransformerMixin, OnnxOperatorMixin):
    """
    Calls :epkg:`onnxruntime` or the runtime implemented
    in this package to transform input based on an ONNX graph.
    It follows :epkg:`scikit-learn` API
    so that it can be included in a :epkg:`scikit-learn` pipeline.
    See notebook :ref:`transferlearningrst` for an example.

    :param onnx_bytes: bytes, or any object implementing
        *SerializeToString* (it is then serialized by the constructor)
    :param output_name: string
        requested output name or None to request all and
        have method *transform* to store all of them in a dataframe
    :param enforce_float32: boolean
        :epkg:`onnxruntime` only supports *float32*,
        :epkg:`scikit-learn` usually uses double floats, this parameter
        ensures that every array of double floats is converted into
        single floats
    :param runtime: string, defines the runtime to use
        as described in @see cl OnnxInference.
    :param change_batch_size: some models are converted for
        a specific batch size, this parameter changes it,
        None to avoid changing it, 0 to fix an undefined
        first dimension
    :param reshape: reshape the output to get
        a matrix and not a multidimensional array
    """

    def __init__(self, onnx_bytes, output_name=None, enforce_float32=True,
                 runtime='python', change_batch_size=None, reshape=False):
        BaseEstimator.__init__(self)
        TransformerMixin.__init__(self)
        # A model object is stored in its serialized form, raw bytes
        # are kept as they are.
        self.onnx_bytes = (onnx_bytes
                           if not hasattr(onnx_bytes, 'SerializeToString')
                           else onnx_bytes.SerializeToString())
        self.output_name = output_name
        self.enforce_float32 = enforce_float32
        self.runtime = runtime
        self.change_batch_size = change_batch_size
        self.reshape = reshape

    def __repr__(self):  # pylint: disable=W0222
        """
        Short representation, the serialized model is truncated
        to its first and last ten bytes when longer than 20 bytes.
        """
        ob = self.onnx_bytes
        if len(ob) > 20:
            ob = ob[:10] + b"..." + ob[-10:]
        return ("{0}(onnx_bytes={1}, output_name={2}, enforce_float32={3}, "
                "runtime='{4}')".format(
                    self.__class__.__name__, ob, self.output_name,
                    self.enforce_float32, self.runtime))

    def fit(self, X=None, y=None, **fit_params):
        """
        Loads the :epkg:`ONNX` model.

        :param X: unused
        :param y: unused
        :param fit_params: additional parameter (unused)
        :return: self
        """
        from ..onnx_tools.optim.onnx_helper import change_input_first_dimension
        onx = onnx.load(BytesIO(self.onnx_bytes))
        self.op_version = onnx_model_opsets(onx)

        output_names = {
            o.name for o in onx.graph.output}  # pylint: disable=E1101
        updated = False
        if (self.output_name is not None and
                self.output_name not in output_names):
            # The model refers to intermediate outputs.
            onx = select_model_inputs_outputs(
                onx, outputs=[self.output_name])
            updated = True

        if self.change_batch_size is not None:
            onx = change_input_first_dimension(
                onx, self.change_batch_size)
            updated = True

        # The model is serialized again only if it was modified.
        onnx_bytes = (
            onx.SerializeToString() if updated else self.onnx_bytes)
        self.onnxrt_ = OnnxInference(onnx_bytes, runtime=self.runtime)
        self.inputs_ = self.onnxrt_.input_names
        self.inputs_shape_types_ = self.onnxrt_.input_names_shapes_types
        return self

    def _check_arrays(self, inputs):
        """
        Ensures that double floats are converted into single floats
        if *enforce_float32* is True or raises an exception.

        :param inputs: dictionary ``{input name: value}``,
            modified inplace, only numpy arrays are checked
        :raises RuntimeError: more inputs than the model expects
            or unexpected shape
        :raises TypeError: unexpected dtype
        """
        has = hasattr(self, "onnxrt_")
        sht = self.inputs_shape_types_ if has else None
        if sht is not None and len(sht) < len(inputs):
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of inputs {} > {} (expected).".format(
                    len(inputs), len(sht)))
        # NOTE(review): every element of *inputs_shape_types_* is assumed
        # to be (name, shape, type); index 1 and 2 are used that way below,
        # confirm index 0 is the input name.
        by_name = {} if sht is None else {e[0]: e for e in sht}
        for i, k in enumerate(inputs):
            v = inputs[k]
            if not isinstance(v, numpy.ndarray):
                continue
            if v.dtype == numpy.float64 and self.enforce_float32:
                inputs[k] = v.astype(numpy.float32)
                continue
            if not has:
                continue
            # Expectations are looked up by name because the dictionary
            # order may differ from the model input order, the position
            # is only used as a fallback for unknown names.
            exp = by_name.get(k, sht[i])
            if exp[1] != ('?', ) and exp[1][1:] != v.shape[1:]:
                raise RuntimeError(  # pragma: no cover
                    "Unexpected shape for input '{}': {} != {} "
                    "(expected).".format(
                        k, v.shape, exp[1]))
            if ((v.dtype == numpy.float32 and exp[2] != 'tensor(float)') or
                    (v.dtype == numpy.float64 and exp[2] != 'tensor(double)')):
                raise TypeError(  # pragma: no cover
                    "Unexpected dtype for input '{}': {} != {} "
                    "(expected).".format(
                        k, v.dtype, exp[2]))

    def transform(self, X, y=None, **inputs):
        """
        Runs the predictions. If *X* is a dataframe,
        the function assumes every columns is a separate input,
        otherwise, *X* is considered as a first input and *inputs*
        can be used to specify extra inputs.

        :param X: iterable, data to process
            (or first input if several expected), a dictionary
            is interpreted as ``{input name: value}``
        :param y: unused
        :param inputs: :epkg:`ONNX` graph support multiple inputs,
            each column of a dataframe is converted into as many inputs if
            *X* is a dataframe, otherwise, *X* is considered as the first
            input and *inputs* can be used to specify the other ones,
            they overwrite any input of the same name coming from *X*
        :return: the raw output if only one output is requested or
            available (a :epkg:`DataFrame` if this output is a list),
            otherwise a :epkg:`DataFrame` concatenating all outputs
        """
        if not hasattr(self, "onnxrt_"):
            raise AttributeError(  # pragma: no cover
                "Transform OnnxTransformer must be fit first.")
        rt_inputs = {}
        if isinstance(X, numpy.ndarray):
            rt_inputs[self.inputs_[0]] = X
        elif isinstance(X, pandas.DataFrame):
            for c in X.columns:
                rt_inputs[c] = X[c]
        elif isinstance(X, dict):
            # A dictionary is merged with the keyword inputs
            # instead of being silently dropped when both are given.
            rt_inputs.update(X)
        elif isinstance(X, list):
            if len(self.inputs_) == 1:
                rt_inputs[self.inputs_[0]] = numpy.array(X)
            else:
                # One column of the rows per expected input.
                for i, name in enumerate(self.inputs_):
                    rt_inputs[name] = [row[i] for row in X]

        rt_inputs.update(inputs)

        names = ([self.output_name]
                 if self.output_name else self.onnxrt_.output_names)
        self._check_arrays(rt_inputs)
        doutputs = self.onnxrt_.run(rt_inputs)
        outputs = [doutputs[n] for n in names]

        if self.reshape:
            n = outputs[0].shape[0]
            outputs = [o.reshape((n, -1)) for o in outputs]

        if self.output_name or len(outputs) == 1:
            if isinstance(outputs[0], list):
                return pandas.DataFrame(outputs[0])
            return outputs[0]

        # Several outputs: every output becomes one or several columns.
        concat = []
        colnames = []
        for k, v in zip(names, outputs):
            if isinstance(v, numpy.ndarray):
                if len(v.shape) == 1:
                    v = v.reshape((-1, 1))
                    colnames.append(k)
                elif len(v.shape) == 2:
                    colnames.extend("%s%d" % (k, i) for i in range(v.shape[1]))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected shape for results %r: %r." % (k, v.shape))
            elif isinstance(v, list):
                if len(v) == 0:
                    raise RuntimeError(  # pragma: no cover
                        "Output %r is empty." % k)
                if not isinstance(v[0], dict):
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected type for output %r - value=%r."
                        "" % (k, v[0]))
                # List of dictionaries: columns are sorted by key
                # to get a deterministic order.
                df = pandas.DataFrame(v)
                cols = list(sorted(df.columns))
                v = df[cols].copy().values
                colnames.extend("%s%d" % (k, i) for i in range(v.shape[1]))
            concat.append(v)
        res = numpy.hstack(concat)
        return pandas.DataFrame(res, columns=colnames)

    def fit_transform(self, X, y=None, **inputs):
        """
        Loads the *ONNX* model and runs the predictions.

        :param X: iterable, data to process
            (or first input if several expected)
        :param y: unused
        :param inputs: :epkg:`ONNX` graph support multiple inputs,
            each column of a dataframe is converted into as many inputs if
            *X* is a dataframe, otherwise, *X* is considered as the first
            input and *inputs* can be used to specify the other ones
        :return: same as method *transform*
        """
        # *fit* ignores its extra parameters, the additional inputs
        # must reach *transform* to be used.
        return self.fit(X, y=y, **inputs).transform(X, y, **inputs)

    @staticmethod
    def enumerate_create(onnx_bytes, output_names=None, enforce_float32=True):
        """
        Creates multiple *OnnxTransformer*,
        one for each requested intermediate node.

        :param onnx_bytes: bytes
        :param output_names: requested output names (a single name
            or an iterable of names) or None to request all of them
        :param enforce_float32: boolean
            :epkg:`onnxruntime` only supports *float32*,
            :epkg:`scikit-learn` usually uses double floats, this parameter
            ensures that every array of double floats is converted into
            single floats
        :return: iterator on OnnxTransformer
            *('output name', OnnxTransformer)*
        """
        if output_names is None:
            selected = None
        elif isinstance(output_names, str):
            # A single name must not be split into characters.
            selected = {output_names}
        else:
            selected = set(output_names)
        model = load_onnx_model(onnx_bytes)
        for out in enumerate_model_node_outputs(model):
            if selected is not None and out not in selected:
                continue
            # The sub-model is only extracted for requested outputs.
            m = select_model_inputs_outputs(model, out)
            tr = OnnxTransformer(m.SerializeToString(),
                                 enforce_float32=enforce_float32)
            yield out, tr

    def onnx_parser(self):
        """
        Returns a parser for this model.
        The parser checks the transformer is fitted, that the number
        of inputs is the expected one and returns the output names
        of the runtime.
        """
        def parser(scope=None, inputs=None):
            if scope is None:
                raise RuntimeError(
                    "scope cannot be None (parser of class %r)."
                    "" % type(self))
            if inputs is None:
                raise RuntimeError(
                    "inputs cannot be None (parser of class %r)."
                    "" % type(self))
            if (not hasattr(self, 'onnxrt_') or
                    not hasattr(self.onnxrt_, 'output_names')):
                raise RuntimeError(  # pragma: no cover
                    'OnnxTransformer not fit.')
            if len(inputs) != len(self.inputs_):
                raise RuntimeError(  # pragma: no cover
                    "Mismatch between the number of inputs, expected %r, "
                    "got %r." % (self.inputs_, inputs))
            return self.onnxrt_.output_names
        return parser

    def onnx_shape_calculator(self):
        """
        Returns a shape calculator for this model.
        It copies the element type and the shape of every output of
        the wrapped ONNX graph into the outputs of the operator.
        Only tensors of floats, doubles, int64 are supported.
        """
        tensor_types = {
            'float': FloatTensorType,
            'int64': Int64TensorType,
            'double': DoubleTensorType,
        }

        def shape_calculator(operator):
            cout = self.onnxrt_.output_names
            if len(operator.outputs) != len(cout):
                raise RuntimeError(  # pragma: no cover
                    "Mismatched number of outputs: {} != {}."
                    "".format(len(operator.outputs), len(cout)))
            for out_op, out in zip(operator.outputs,
                                   self.onnxrt_.obj.graph.output):
                var = _var_as_dict(out)
                if var['type']['kind'] != 'tensor':
                    raise NotImplementedError(  # pragma: no cover
                        "Not yet implemented for output:\n{}".format(out))
                shape = var['type']['shape']
                # A first dimension equal to 0 is replaced by None,
                # an empty shape is left unchanged.
                if len(shape) > 0 and shape[0] == 0:
                    shape = (None,) + tuple(shape[1:])
                elem = var['type']['elem']
                if elem not in tensor_types:
                    raise NotImplementedError(  # pragma: no cover
                        "Not yet implemented for elem_type:\n{}".format(elem))
                out_op.type = tensor_types[elem](shape=shape)
        return shape_calculator

    def onnx_converter(self):
        """
        Returns a converter for this model.
        The converter inserts the ONNX graph held by the fitted
        transformer (or *onnx_model* if specified) into the container.
        """
        def converter(scope, operator, container, onnx_model=None):
            op = operator.raw_operator
            onx = onnx_model or op.onnxrt_.obj
            add_onnx_graph(scope, operator, container, onx)

        return converter

    @property
    def opsets(self):
        """
        Returns the opsets as dictionary ``{domain: opset}``.
        The model comes from the runtime if the transformer is fitted,
        otherwise it is loaded from *onnx_bytes*.
        """
        if hasattr(self, 'onnxrt_'):
            model = self.onnxrt_.obj
        else:
            model = load_onnx_model(self.onnx_bytes)
        return {oimp.domain: oimp.version for oimp in model.opset_import}

338 else: 

339 model = load_onnx_model(self.onnx_bytes) 

340 res = {} 

341 for oimp in model.opset_import: 

342 res[oimp.domain] = oimp.version 

343 return res