Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2""" 

3@file 

4@brief Shortcut to *ops_onnxruntime*. 

5""" 

6import numpy 

7import onnx.defs 

8from onnx.helper import make_tensor 

9from onnx.onnx_cpp2py_export.shape_inference import InferenceError # pylint: disable=E0401,E0611 

10from skl2onnx.common.data_types import ( 

11 DictionaryType, FloatTensorType, Int64TensorType, StringTensorType) 

12import skl2onnx.algebra.onnx_ops as alg 

13try: 

14 import skl2onnx.algebra.custom_ops as alg2 

15except ImportError: # pragma: no cover 

16 # older version of skl2onnx 

17 alg2 = alg 

18from ...tools.ort_wrapper import ( 

19 InferenceSession, SessionOptions, RunOptions, 

20 GraphOptimizationLevel, OrtInvalidArgument, 

21 OrtNotImplemented, OrtInvalidGraph, OrtFail) 

22from ...onnx_tools.onnx2py_helper import guess_proto_dtype 

23from ...onnx_tools.optim.graph_schema_helper import ( 

24 get_defined_inputs, get_defined_outputs, proto2vars) 

25 

26 

# Index every known ONNX operator schema by operator name.
# ``get_all_schemas_with_history`` yields several versions of the same
# operator; later entries overwrite earlier ones, so the last yielded
# version is the one kept.
_schemas = dict(
    (schema.name, schema)
    for schema in onnx.defs.get_all_schemas_with_history())

29 

30 

class OpRunOnnxRuntime:
    """
    Unique operator which calls :epkg:`onnxruntime`
    to compute predictions for one operator.

    The node is converted back into a standalone ONNX graph
    (through :epkg:`skl2onnx`) and loaded into its own
    :epkg:`onnxruntime` inference session.
    """

    def __init__(self, onnx_node, desc=None, variables=None,
                 dtype=None, **options):
        """
        @param      onnx_node   :epkg:`onnx` node
        @param      desc        internal representation
        @param      variables   registered variables created by previous operators
        @param      dtype       float computation type
        @param      options     runtime options
        """
        self._provider = 'onnxruntime'
        self.onnx_node = onnx_node
        self.desc = desc
        # Operator schema looked up by type; None when the operator is
        # unknown to this onnx release (e.g. a custom domain).
        self._schema = _schemas.get(onnx_node.op_type, None)
        if desc is not None:
            if 'atts' in desc:
                # Flatten the parsed attributes into keyword options.
                for a, b in desc['atts'].items():
                    if not isinstance(b, dict) or 'value' not in b:
                        raise ValueError(  # pragma: no cover
                            "Unexpected value {}.".format(b))
                    options[a] = b['value']

        self.options = options
        self.dtype = dtype
        self._init(variables)

    def _name_mapping(self, inputs):
        """
        Renames duplicated input names so that every input of the
        rebuilt graph is unique.

        :param inputs: iterable of input names (may contain duplicates)
        :return: tuple ``(mapping, new_inputs)`` where *mapping* maps
            every (possibly renamed) input back to its original name and
            *new_inputs* is the deduplicated list, in order
        """
        mapping = {}
        new_inputs = []
        for name in inputs:
            if name in mapping:
                # Already seen: derive a fresh name `name_<i>`.
                i = 0
                new_name = "{}_{}".format(name, i)
                while new_name in mapping:
                    i += 1  # pragma: no cover
                    new_name = "{}_{}".format(name, i)  # pragma: no cover
                mapping[new_name] = name
                new_inputs.append(new_name)
            else:
                new_inputs.append(name)
                mapping[name] = name
        return mapping, new_inputs

    def _guess_proto_type(self, dtype):
        """Maps a :epkg:`numpy` dtype to the corresponding proto type."""
        return guess_proto_dtype(dtype)

    def _init(self, variables=None):
        """
        Initializes the node.

        :param variables: registered variables created by previous operators

        The current implementation for operator *Scan*
        only works for matrices.
        """
        custom_nodes = self.options.get('nodes', None)
        if (custom_nodes is not None and
                self.onnx_node.op_type in custom_nodes):
            self.alg_class = custom_nodes[self.onnx_node.op_type]
        else:
            # Prefer the custom-ops module, fall back on the standard
            # skl2onnx algebra when the operator is not defined there.
            try:
                self.alg_class = getattr(alg2, 'Onnx' + self.onnx_node.op_type)
            except AttributeError:
                self.alg_class = getattr(alg, 'Onnx' + self.onnx_node.op_type)

        inputs = list(self.onnx_node.input)
        self.mapping, self.inputs = self._name_mapping(inputs)
        self.outputs = list(self.onnx_node.output)

        options = self.options.copy()
        options.pop('nodes', None)
        target_opset = options.pop('target_opset', None)
        domain = options.pop('domain', None)
        disable_optimisation = options.pop('disable_optimisation', False)
        # Default must be None (not False): the "session_options is None"
        # branch below configures log levels and honours
        # disable_optimisation only when the caller did not supply
        # custom session options.
        session_options = options.pop('session_options', None)
        ir_version = options.pop('ir_version', None)

        # target_opset may legitimately be None; guard the comparison to
        # avoid a TypeError (None < 9) on Python 3.
        if domain == '' and target_opset is not None and target_opset < 9:
            # target_opset should be >= 9 not {} for main domain.
            # We assume it was the case when the graph was created.
            pass

        if self.onnx_node.op_type == 'ZipMap':
            # ZipMap produces a dictionary output; its type cannot be
            # inferred and must be forced explicitly.
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            name = (self.outputs[0] if len(self.outputs) == 1
                    else self.inst_.expected_outputs[0][0])
            otype = (Int64TensorType if 'classlabels_int64s' in options
                     else StringTensorType)
            outvar = [(name, DictionaryType(otype([1]), FloatTensorType([1])))]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outvar)
            forced = True
        elif self.onnx_node.op_type == 'ConstantOfShape':
            # The 'value' attribute must be given as a TensorProto.
            for k in options:
                v = options[k]
                if isinstance(v, numpy.ndarray):
                    options[k] = make_tensor(
                        k, self._guess_proto_type(v.dtype),
                        v.shape, v.tolist())

            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            try:
                self.onnx_ = self.inst_.to_onnx(inputs, target_opset=target_opset,
                                                domain=domain)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_))
            except AttributeError as e:  # pragma: no cover
                # older version of skl2onnx
                self.onnx_ = self.inst_.to_onnx(inputs)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_)) from e
            forced = False
        elif self.onnx_node.op_type == 'Scan':
            self.inst_ = self.alg_class(
                *self.inputs, output_names=self.outputs,
                op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype)
            # Scan is only supported on matrices: force every input and
            # output to a rank-2 type with unknown dimensions.
            inputs = [(name, cl.__class__([None, None]))
                      for (name, cl) in inputs]
            outputs = [(name, cl.__class__([None, None]))
                       for (name, cl) in outputs]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                            target_opset=target_opset,
                                            domain=domain)
            if "dim_value: 0" in str(self.onnx_):
                raise RuntimeError(  # pragma: no cover
                    "Probable issue as one dimension is null.\n--\n{}".format(
                        self.onnx_))
            forced = True
        else:
            # Generic case: let shape inference determine the outputs,
            # and only force them when inference fails.
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, domain=domain,
                                        **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype,
                schema=self.alg_class.expected_inputs)

            try:
                self.onnx_ = self.inst_.to_onnx(
                    inputs, target_opset=target_opset, domain=domain)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}\n---\n{}".format(
                            self.onnx_, inputs))
                forced = False
            except (RuntimeError, ValueError, InferenceError) as eo:
                # Let's try again by forcing output types.
                forced = True
                outputs = get_defined_outputs(
                    self.outputs, self.onnx_node, inputs, variables,
                    dtype=self.dtype, schema=self.alg_class.expected_outputs,
                    schema_inputs=self.alg_class.expected_inputs)
                try:
                    self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                                    target_opset=target_opset,
                                                    domain=domain)
                except NotImplementedError as e:
                    raise NotImplementedError(
                        "Unable to instantiate node {} inputs={} "
                        "self.inputs={} outputs={} variables={} "
                        "dtype={} e={} eo={}".format(
                            self.alg_class, inputs, self.inputs,
                            outputs, variables, self.dtype, e, eo)) from e
                if "dim_value: 0" in str(self.onnx_):
                    # Chain from `eo`: `e` is unbound here, its except
                    # clause deleted it when it exited.
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_)) from eo

        if len(self.onnx_.graph.output) != len(self.outputs):  # pragma: no cover
            # Something is wrong, falls back to default plan.
            forced = True
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype, schema=self.alg_class.expected_outputs)
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                            target_opset=target_opset,
                                            domain=domain)
            if "dim_value: 0" in str(self.onnx_):
                raise RuntimeError(  # pragma: no cover
                    "Probable issue as one dimension is null.\n--\n{}".format(
                        self.onnx_))
        else:
            lo = list(self.onnx_.graph.output)
            outputs = proto2vars(lo)

        sess_options = session_options or SessionOptions()
        self.run_options = RunOptions()

        if session_options is None:
            # No custom options supplied: lower onnxruntime verbosity.
            try:
                sess_options.session_log_severity_level = 3
                # sess_options.sessions_log_verbosity_level = 0
            except AttributeError:
                # onnxruntime not recent enough.
                pass
            try:
                self.run_options.run_log_severity_level = 3
                # self.run_options.run_log_verbosity_level = 0
            except AttributeError:
                # onnxruntime not recent enough.
                pass
            if disable_optimisation:
                sess_options.graph_optimization_level = (  # pragma: no cover
                    GraphOptimizationLevel.ORT_DISABLE_ALL)
        elif disable_optimisation:
            # Caller supplied its own SessionOptions; refusing to
            # silently override its optimisation level.
            raise RuntimeError(  # pragma: no cover
                "session_options and disable_optimisation cannot be defined "
                "at the same time.")

        if ir_version is not None:
            self.onnx_.ir_version = ir_version
        try:
            self.sess_ = InferenceSession(
                self.onnx_.SerializeToString(), sess_options=sess_options)
        except (RuntimeError, OrtNotImplemented, OrtInvalidGraph, OrtFail) as e:
            raise RuntimeError(
                "Unable to load node '{}' (output type was {}) inputs={} "
                "self.inputs={} self.onnx_node.input={} "
                "variables={} mapping={} "
                "expected_inputs={}\n{}".format(
                    self.onnx_node.op_type,
                    "guessed" if forced else "inferred",
                    inputs, self.inputs, self.onnx_node.input,
                    variables, self.mapping,
                    self.alg_class.expected_inputs,
                    self.onnx_)) from e
        self.typed_outputs_ = outputs

    def run(self, *args, **kwargs):
        """
        Should be overwritten.

        :param args: inputs, matched positionally against ``self.inputs``
        :return: tuple of outputs computed by :epkg:`onnxruntime`
        :raises RuntimeError: when the session fails, with a message
            detailing input names, dtypes and shapes
        """
        inputs = {name: val for name, val in zip(self.inputs, args)}

        try:
            res = self.sess_.run(None, inputs, self.run_options)
        except (RuntimeError, OrtInvalidArgument) as e:  # pragma: no cover
            dtypes = {k: v.dtype for k, v in inputs.items()}
            shapes = {k: v.shape for k, v in inputs.items()}
            exp = [_.name for _ in self.sess_.get_inputs()]
            exp_types = [_.type for _ in self.sess_.get_inputs()]
            raise RuntimeError(
                "Predictions failed. List of inputs: {}, class={}"
                "\ndtypes={}\nshapes={}\nexpected={}\nexpected={}\n"
                "exception={}\n--ONNX--\n{}".format(
                    list(sorted(inputs)), self.alg_class, dtypes,
                    shapes, exp, exp_types, e, self.onnx_)) from e
        return tuple(res)

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return False