Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

"""
@file
@brief Inspired by :epkg:`sklearn-onnx`, handles two backends.
"""

5import os 

6import pickle 

7import numpy 

8from numpy.testing import assert_array_almost_equal, assert_array_equal 

9from scipy.sparse.csr import csr_matrix 

10import pandas 

11from ...onnxrt.ops_cpu.op_zipmap import ArrayZipMapDictionary 

12 

13 

class ExpectedAssertionError(Exception):
    """
    Failure which is expected, it signals a discrepancy
    already known and tolerated by the caller.
    """

19 

20 

class OnnxBackendAssertionError(AssertionError):
    """
    Unexpected failure: the output computed by a backend
    does not match the expected one or has an unexpected type.
    It inherits from *AssertionError* so that it is reported
    as a regular assertion failure.
    """

26 

27 

class OnnxBackendMissingNewOnnxOperatorException(OnnxBackendAssertionError):
    """
    Raised when a new operator defined in the latest onnx
    is not implemented by :epkg:`onnxruntime`
    or :epkg:`mlprodict`.
    """

35 

36 

class OnnxRuntimeMissingNewOnnxOperatorException(OnnxBackendAssertionError):
    """
    Raised when an operator was recently added
    but cannot be found.
    """

42 

43 

def evaluate_condition(backend, condition):
    """
    Evaluates a condition written as a python expression such as
    ``StrictVersion(onnxruntime.__version__) <= StrictVersion('0.1.3')``.

    @param      backend     backend name, only ``'onnxruntime'``
                            is supported
    @param      condition   string, expression to evaluate
    @return                 value of the expression

    Any other backend raises *NotImplementedError*.
    """
    if backend != "onnxruntime":
        raise NotImplementedError(  # pragma no cover
            "Not implemented for backend '{0}' and "
            "condition '{1}'.".format(backend, condition))
    # The module is imported locally so that the expression
    # can refer to name ``onnxruntime``.
    import onnxruntime  # pylint: disable=W0611
    # NOTE: eval executes arbitrary code, *condition* must be trusted.
    return eval(condition)  # pylint: disable=W0123

55 

56 

def is_backend_enabled(backend):
    """
    Tells if a backend is enabled.

    @param      backend     ``'onnxruntime'`` or ``'python'``
    @return                 boolean

    Backend ``'python'`` is always enabled, backend ``'onnxruntime'``
    is enabled if the module can be imported.
    Any other name raises *NotImplementedError*.
    """
    if backend == "python":
        return True
    if backend != "onnxruntime":
        raise NotImplementedError(  # pragma no cover
            "Not implemented for backend '{0}'".format(backend))
    try:
        import onnxruntime  # pylint: disable=W0611
    except ImportError:  # pragma no cover
        return False
    return True

73 

74 

def load_data_and_model(items_as_dict, **context):
    """
    Loads every file in a dictionary {key: filename}.
    A string ending with extension *.pkl* is unpickled,
    any other string (an *onnx* filename for example) is kept
    unchanged. If the value is not a string,
    the function assumes it was already loaded.

    @param      items_as_dict   dictionary ``{key: filename or object}``
    @param      context         unused by this function
    @return                     dictionary ``{key: loaded object}``

    A pickle file which cannot be loaded because of a missing
    import is skipped if its name contains ``'.model.'``,
    otherwise an *ImportError* is raised.
    """
    res = {}
    for k, v in items_as_dict.items():
        if not isinstance(v, str) or os.path.splitext(v)[-1] != ".pkl":
            res[k] = v
            continue
        with open(v, "rb") as f:  # pragma: no cover
            try:
                # NOTE: pickle may execute arbitrary code,
                # the files must come from a trusted source.
                loaded = pickle.load(f)
            except ImportError as e:
                if '.model.' in v:
                    # best effort: the model relies on a module
                    # which is not available, the key is left out
                    continue
                raise ImportError(
                    "Unable to load '{0}' due to {1}".format(v, e)) from e
        res[k] = loaded
    return res

100 

101 

def extract_options(name):
    """
    Extracts comparison options from a filename.
    As example, ``Binarizer-SkipDim1`` means
    option *SkipDim1* is enabled:
    ``(1, 2)`` and ``(2,)`` are considered equal.
    Available options: see :func:`dump_data_and_model`.

    @param      name    filename, the folder and the extension
                        are ignored, options are separated by ``-``
    @return             dictionary ``{option: True}``, empty if
                        the name does not contain any option

    An unknown option raises *NameError*.
    """
    known = ("SkipDim1", "OneOff", "NoProb", "NoProbOpp",
             "Dec4", "Dec3", "Dec2", 'Svm',
             'Out0', 'Reshape', 'SklCol', 'DF', 'OneOffArray')
    # keeps the last part of the path, drops the extension,
    # the first chunk is the model name, the others are options
    opts = name.replace("\\", "/").split("/")[-1].split('.')[0].split('-')
    res = {}
    for opt in opts[1:]:
        if opt not in known:
            raise NameError(  # pragma no cover
                "Unable to parse option '{}' in '{}'".format(
                    opt, '-'.join(opts)))
        res[opt] = True
    return res

123 

124 

def compare_outputs(expected, output, verbose=False, **kwargs):
    """
    Compares expected values and output.
    Returns None if no error, an exception message otherwise.

    @param      expected    expected array
    @param      output      array to compare to *expected*
    @param      verbose     adds both arrays to the error message
    @param      kwargs      comparison options *SkipDim1*, *NoProb*,
                            *NoProbOpp*, *Dec4*, *Dec3*, *Dec2*,
                            *Disc*, *Mism*, the remaining parameters
                            (*decimal* for example) are given to
                            ``assert_array_almost_equal``
    @return                 None if both arrays agree, otherwise an
                            exception which is returned and not raised,
                            @see cl ExpectedAssertionError if the
                            discrepancy is covered by option *Disc*
                            or *Mism*, @see cl OnnxBackendAssertionError
                            in any other case

    *NotImplementedError* is raised if option *NoProb* or *NoProbOpp*
    is enabled and the shapes cannot be reconciled.
    """
    SkipDim1 = kwargs.pop("SkipDim1", False)
    NoProb = kwargs.pop("NoProb", False)
    NoProbOpp = kwargs.pop("NoProbOpp", False)
    Dec4 = kwargs.pop("Dec4", False)
    Dec3 = kwargs.pop("Dec3", False)
    Dec2 = kwargs.pop("Dec2", False)
    Disc = kwargs.pop("Disc", False)
    Mism = kwargs.pop("Mism", False)

    # Options Dec* lower the precision but never increase it.
    # 6 is the default value of parameter decimal in
    # assert_array_almost_equal, it is used if decimal is not specified.
    for enabled, cap in ((Dec4, 4), (Dec3, 3), (Dec2, 2)):
        if enabled:
            kwargs["decimal"] = min(kwargs.get("decimal", 6), cap)

    if not (isinstance(expected, numpy.ndarray) and
            isinstance(output, numpy.ndarray)):
        return OnnxBackendAssertionError(  # pragma: no cover
            "Unexpected types {0} != {1}".format(
                type(expected), type(output)))

    if SkipDim1:
        # Arrays like (2, 1, 2, 3) becomes (2, 2, 3)
        # as one dimension is useless.
        # The output is reshaped into the shape of the expected array.
        squeezed = tuple(d for d in expected.shape if d > 1)
        expected = expected.reshape(squeezed)
        output = output.reshape(squeezed)
    if NoProb or NoProbOpp:
        # One vector is (N,) with scores, negative for class 0
        # positive for class 1
        # The other vector is (N, 2) score in two columns.
        if (len(output.shape) == 2 and output.shape[1] == 2 and
                len(expected.shape) == 1):
            output = output[:, 1]
            if NoProbOpp:
                output = -output
        elif len(output.shape) == 1 and len(expected.shape) == 1:
            pass
        elif (len(expected.shape) == 1 and len(output.shape) == 2 and
                expected.shape[0] == output.shape[0] and
                output.shape[1] == 1):
            output = output[:, 0]
            if NoProbOpp:
                output = -output
        elif expected.shape != output.shape:
            raise NotImplementedError(  # pragma no cover
                "Shape mismatch: {0} != {1}".format(
                    expected.shape, output.shape))
    if (len(expected.shape) == 1 and len(output.shape) == 2 and
            output.shape[1] == 1):
        output = output.ravel()
    if (len(output.shape) == 3 and output.shape[0] == 1 and
            len(expected.shape) == 2):
        output = output.reshape(output.shape[1:])

    if expected.dtype.kind == "U":
        # strings of any length are compared with a strict equality
        try:
            assert_array_equal(expected, output, verbose=verbose)
        except Exception as e:  # pylint: disable=W0703
            if Disc:  # pragma no cover
                # Bug to be fixed later.
                return ExpectedAssertionError(str(e))
            return OnnxBackendAssertionError(str(e))  # pragma no cover
        return None

    try:
        assert_array_almost_equal(
            expected, output, verbose=verbose, **kwargs)
    except (RuntimeError, AssertionError) as e:  # pragma no cover
        longer = "\n--EXPECTED--\n{0}\n--OUTPUT--\n{1}".format(
            expected, output) if verbose else ""
        expected_ = numpy.asarray(expected).ravel()
        output_ = numpy.asarray(output).ravel()
        if len(expected_) != len(output_):
            error_type = (ExpectedAssertionError if Mism
                          else OnnxBackendAssertionError)
            return error_type(
                "dimension mismatch={0}, {1}\n{2}{3}".format(
                    expected.shape, output.shape, e, longer))
        if len(expected_) == 0:
            # nothing to compare, max() fails on an empty sequence
            diff = 0
        elif numpy.issubdtype(expected_.dtype, numpy.floating):
            diff = numpy.abs(expected_ - output_).max()
        else:
            diff = max((1 if ci != cj else 0)
                       for ci, cj in zip(expected_, output_))
        if diff == 0:
            # same values once flattened, only the shapes differ
            return None
        # Disc: bug to be fixed later.
        error_type = (ExpectedAssertionError if Disc
                      else OnnxBackendAssertionError)
        return error_type(
            "max-diff={0}\n--expected--output--\n{1}{2}".format(
                diff, e, longer))
    return None

230 

231 

232def _post_process_output(res): 

233 """ 

234 Applies post processings before running the comparison 

235 such as changing type from list to arrays. 

236 """ 

237 if isinstance(res, list): 

238 if len(res) == 0: 

239 return res 

240 if len(res) == 1: 

241 return _post_process_output(res[0]) 

242 if isinstance(res[0], numpy.ndarray): 

243 return numpy.array(res) 

244 if isinstance(res[0], dict): 

245 return pandas.DataFrame(res).values 

246 ls = [len(r) for r in res] 

247 mi = min(ls) 

248 if mi != max(ls): 

249 raise NotImplementedError( # pragma no cover 

250 "Unable to postprocess various number of " 

251 "outputs in [{0}, {1}]" 

252 .format(min(ls), max(ls))) 

253 if mi > 1: 

254 output = [] 

255 for i in range(mi): 

256 output.append(_post_process_output([r[i] for r in res])) 

257 return output 

258 if isinstance(res[0], list): 

259 # list of lists 

260 if isinstance(res[0][0], list): 

261 return numpy.array(res) 

262 if len(res[0]) == 1 and isinstance(res[0][0], dict): 

263 return _post_process_output([r[0] for r in res]) 

264 if len(res) == 1: 

265 return res 

266 if len(res[0]) != 1: 

267 raise NotImplementedError( # pragma no cover 

268 "Not conversion implemented for {0}".format(res)) 

269 st = [r[0] for r in res] 

270 return numpy.vstack(st) 

271 return res 

272 return res 

273 

274 

275def _create_column(values, dtype): 

276 "Creates a column from values with dtype" 

277 if str(dtype) == "tensor(int64)": 

278 return numpy.array(values, dtype=numpy.int64) 

279 if str(dtype) == "tensor(float)": 

280 return numpy.array(values, dtype=numpy.float32) 

281 if str(dtype) in ("tensor(double)", "tensor(float64)"): 

282 return numpy.array(values, dtype=numpy.float64) 

283 if str(dtype) in ("tensor(string)", "tensor(str)"): 

284 return numpy.array(values, dtype=numpy.str_) 

285 raise OnnxBackendAssertionError( 

286 "Unable to create one column from dtype '{0}'".format(dtype)) 

287 

288 

def _compare_expected(expected, output, sess, onnx_model,
                      decimal=5, verbose=False, classes=None,
                      **kwargs):
    """
    Compares the expected output against the runtime outputs.
    This is specific to :epkg:`onnxruntime` or :epkg:`mlprodict`.

    @param      expected    expected outputs, a list, a dictionary,
                            an array or a sparse matrix (*csr_matrix*)
    @param      output      outputs produced by the runtime
    @param      sess        not used by the comparison, only given
                            to the recursive calls
    @param      onnx_model  only used in error messages
    @param      decimal     precision, given to @see fn compare_outputs
    @param      verbose     given to @see fn compare_outputs
    @param      classes     if not None and *expected* contains strings,
                            every output value is replaced by
                            ``classes[value]`` before the comparison
    @param      kwargs      options *Out0* (only the first output is
                            compared), *Reshape* (the outputs are stacked
                            and reshaped to get one row per expected
                            output), the others are given to
                            @see fn compare_outputs

    The function raises @see cl OnnxBackendAssertionError if the outputs
    differ or if nothing was compared,
    @see cl ExpectedAssertionError if the discrepancy is a known one
    (array case only).
    """
    tested = 0
    if isinstance(expected, list):
        if not isinstance(output, list):
            raise OnnxBackendAssertionError(  # pragma no cover
                "Type mismatch for '{0}', output type is {1}".format(
                    onnx_model, type(output)))
        if 'Out0' in kwargs:
            expected = expected[:1]
            output = output[:1]
            del kwargs['Out0']
        if 'Reshape' in kwargs:
            del kwargs['Reshape']
            output = numpy.hstack(output).ravel()
            output = output.reshape(
                (len(expected), len(output.ravel()) // len(expected)))
        if len(expected) != len(output):
            raise OnnxBackendAssertionError(  # pragma no cover
                "Unexpected number of outputs '{0}', expected={1}, got={2}"
                .format(onnx_model, len(expected), len(output)))
        for exp, out in zip(expected, output):
            # the precision requested by the caller is propagated
            _compare_expected(exp, out, sess, onnx_model, decimal=decimal,
                              verbose=verbose, classes=classes, **kwargs)
            tested += 1
    elif isinstance(expected, dict):
        if not isinstance(output, dict):
            raise OnnxBackendAssertionError(  # pragma no cover
                "Type mismatch for '{0}'".format(onnx_model))
        for k, v in output.items():
            if k not in expected:
                # outputs without any expected value are ignored
                continue
            msg = compare_outputs(
                expected[k], v, decimal=decimal, verbose=verbose, **kwargs)
            if msg:
                raise OnnxBackendAssertionError(  # pragma no cover
                    "Unexpected output '{0}' in model '{1}'\n{2}".format(
                        k, onnx_model, msg))
            tested += 1
    elif isinstance(expected, numpy.ndarray):
        if isinstance(output, list):
            if expected.shape[0] == len(output) and isinstance(
                    output[0], dict):
                # one dictionary per row, they are converted into
                # a matrix whose columns are sorted by name
                if isinstance(output, ArrayZipMapDictionary):
                    output = pandas.DataFrame(list(output))
                else:
                    output = pandas.DataFrame(output)
                output = output[list(sorted(output.columns))]
                output = output.values
        if isinstance(output, (dict, list)):
            if len(output) != 1:  # pragma: no cover
                ex = str(output)
                if len(ex) > 170:
                    ex = ex[:170] + "..."
                raise OnnxBackendAssertionError(
                    "More than one output when 1 is expected "
                    "for onnx '{0}'\n{1}"
                    .format(onnx_model, ex))
            output = output[-1]
        if not isinstance(output, numpy.ndarray):
            raise OnnxBackendAssertionError(  # pragma no cover
                "output must be an array for onnx '{0}' not {1}".format(
                    onnx_model, type(output)))
        if (classes is not None and (
                expected.dtype == numpy.str_ or expected.dtype.char == 'U')):
            try:
                output = numpy.array([classes[cl] for cl in output])
            except IndexError as e:  # pragma no cover
                raise RuntimeError('Unable to handle\n{}\n{}\n{}'.format(
                    expected, output, classes)) from e
        msg = compare_outputs(
            expected, output, decimal=decimal, verbose=verbose, **kwargs)
        if isinstance(msg, ExpectedAssertionError):
            # known discrepancy, the exception is raised unchanged
            raise msg  # pylint: disable=E0702
        if msg:
            raise OnnxBackendAssertionError(  # pragma no cover
                "Unexpected output in model '{0}'\n{1}".format(
                    onnx_model, msg))
        tested += 1
    elif isinstance(expected, csr_matrix):
        # DictVectorizer
        one_array = numpy.array(output)
        dense = numpy.asarray(expected.todense())
        msg = compare_outputs(dense, one_array, decimal=decimal,
                              verbose=verbose, **kwargs)
        if msg:
            raise OnnxBackendAssertionError(  # pragma no cover
                "Unexpected output in model '{0}'\n{1}".format(
                    onnx_model, msg))
        tested += 1
    else:
        raise OnnxBackendAssertionError(  # pragma no cover
            "Unexpected type for expected output ({1}) and onnx '{0}'".
            format(onnx_model, type(expected)))
    if tested == 0:
        raise OnnxBackendAssertionError(  # pragma no cover
            "No test for onnx '{0}'".format(onnx_model))
390 "No test for onnx '{0}'".format(onnx_model))