Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Inspired from skl2onnx, handles two backends. 

4""" 

5import numpy 

6import onnx 

7import pandas 

8from ...tools.ort_wrapper import OrtInvalidArgument 

9from .utils_backend_common import ( 

10 load_data_and_model, extract_options, 

11 ExpectedAssertionError, OnnxBackendAssertionError, 

12 OnnxRuntimeMissingNewOnnxOperatorException, 

13 _compare_expected, _create_column) 

14 

15 

def _model_dump_for_error(onx, verbose):
    """
    Returns a textual dump of the ONNX model meant to be appended
    to error messages.

    :param onx: ONNX model, a filename (or anything :func:`onnx.load`
        accepts) or an already loaded :epkg:`ModelProto`
    :param verbose: the model is only dumped when *verbose* is True
    :return: string starting with ``JSON ONNX`` or an empty string
    """
    if not verbose:
        return ""
    # onnx.load cannot take a model which is already loaded,
    # such a model can be printed as it is.
    model = onx if isinstance(onx, onnx.ModelProto) else onnx.load(onx)
    return "\nJSON ONNX\n" + str(model)


def _session_from_model(cls_session, onx, options, disable_optimisation,
                        verbose):
    """
    Creates the inference session and converts loading failures into
    the exceptions the test framework expects.

    :param cls_session: inference session class
    :param onx: ONNX model (filename or object)
    :param options: comparison options, key ``'CannotLoad'`` means
        a loading failure is expected
    :param disable_optimisation: forwarded to the runtime through
        ``runtime_options``
    :param verbose: adds a dump of the model to error messages
    :return: session instance
    :raises TypeError: *cls_session* does not accept this signature
    :raises ExpectedAssertionError: loading failed and it was expected
    :raises OnnxRuntimeMissingNewOnnxOperatorException: the runtime
        reports a node without any implementation
    :raises OnnxBackendAssertionError: any other loading failure
    """
    runtime_options = dict(disable_optimisation=disable_optimisation)
    try:
        return cls_session(onx, runtime_options=runtime_options)
    except TypeError as et:  # pragma: no cover
        raise TypeError(
            "Wrong signature for '{}' ({}).".format(
                cls_session.__name__, et)) from et
    except ExpectedAssertionError:  # pragma no cover
        raise
    except Exception as e:  # pylint: disable=W0703
        if "CannotLoad" in options:  # pragma no cover
            raise ExpectedAssertionError(
                "Unable to load onnx '{0}' due to\n{1}".format(
                    onx, e)) from e
        smodel = _model_dump_for_error(onx, verbose)
        msg = str(e)
        if ("NOT_IMPLEMENTED : Could not find an implementation "
                "for the node" in msg):
            # onnxruntime does not implement a specific node yet.
            raise OnnxRuntimeMissingNewOnnxOperatorException(
                "{3} does not implement a new operator "
                "'{0}'\n{1}\nONNX\n{2}".format(
                    onx, e, smodel, cls_session)) from e
        if "NOT_IMPLEMENTED : Failed to find kernel" in msg:
            # onnxruntime does not implement a specific node yet
            # in the kernel included in onnxruntime.
            raise OnnxBackendAssertionError(
                "{3} misses a kernel for operator "
                "'{0}'\n{1}\nONNX\n{2}".format(
                    onx, e, smodel, cls_session)) from e
        raise OnnxBackendAssertionError(
            "Unable to load onnx '{0}'\nONNX\n{1}\n{2}".format(
                onx, smodel, e)) from e


def _inputs_from_data(sess, data, onx, as_columns):
    """
    Converts the test data into a dictionary ``{input name: value}``.

    :param sess: inference session, ``get_inputs()`` and
        ``get_outputs()`` are called when *data* is a list, an array
        or a dataframe
    :param data: dictionary, list, :epkg:`numpy` array or dataframe
    :param onx: ONNX model, only used in error messages
    :param as_columns: value of option ``'DF'``: every column of the
        dataframe *data* becomes an input of shape ``(N, 1)``,
        float64 columns are cast into float32
    :return: dictionary (*data* itself if it is already a dictionary)
    :raises OnnxBackendAssertionError: *data* cannot be mapped onto
        the model inputs
    """
    if as_columns:
        inputs = {}
        for col in data.columns:
            values = data[col].values
            if values.dtype == numpy.float64:
                values = values.astype(numpy.float32)
            inputs[col] = values.reshape((values.shape[0], 1))
        return inputs

    if isinstance(data, dict):
        return data
    if not isinstance(data, (list, numpy.ndarray, pandas.DataFrame)):
        raise OnnxBackendAssertionError(  # pragma no cover
            "Dict or list is expected, not {0}".format(type(data)))

    inp = sess.get_inputs()
    if len(sess.get_outputs()) == 0:
        raise OnnxBackendAssertionError(  # pragma: no cover
            "Wrong number of outputs, onnx='{0}'".format(onx))
    if len(inp) == len(data):
        # One element of data per input.
        # NOTE(review): for an array or a dataframe, this branch is also
        # taken when the number of rows happens to equal the number of
        # inputs. Confirm whether that is intended.
        return {i.name: v for i, v in zip(inp, data)}
    if len(inp) == 1:
        return {inp[0].name: data}

    if isinstance(data, numpy.ndarray):
        # input number i receives column i of the array
        shape = sum(i.shape[1] if len(i.shape) == 2 else i.shape[0]
                    for i in inp)
        if shape != data.shape[1]:
            raise OnnxBackendAssertionError(  # pragma: no cover
                "Wrong number of inputs onnx {0} != "
                "original shape {1}, onnx='{2}'"
                .format(len(inp), data.shape, onx))
        return {n.name: data[:, i] for i, n in enumerate(inp)}

    # List or dataframe: consecutive columns are dispatched among the
    # inputs, each input takes as many columns as its second dimension.
    is_list = isinstance(data, list)
    try:
        array_input = numpy.array(data)
    except Exception as e:  # pragma no cover  # pylint: disable=W0703
        raise OnnxBackendAssertionError(
            "Wrong number of inputs onnx {0} != "
            "original {1}, onnx='{2}'"
            .format(len(inp), len(data), onx)) from e
    shape = sum(i.shape[1] for i in inp)
    if shape != array_input.shape[1]:
        if is_list:
            raise OnnxBackendAssertionError(  # pragma no cover
                "Wrong number of inputs onnx {0} != "
                "original shape {1}, onnx='{2}'*"
                .format(len(inp), array_input.shape, onx))
        raise OnnxBackendAssertionError(  # pragma no cover
            "Wrong number of inputs onnx {0}={1} columns != "
            "original shape {2}, onnx='{3}'*"
            .format(len(inp), shape, array_input.shape, onx))
    inputs = {}
    start = 0
    for node in inp:
        end = start + node.shape[1]
        if is_list:
            chunk = [row[start:end] for row in data]
        else:
            chunk = data.iloc[:, start:end]
        inputs[node.name] = _create_column(chunk, node.type)
        start = end
    return inputs


def _run_session(sess, inputs, onx, cls_session, verbose,
                 intermediate_steps):
    """
    Runs the session and converts failures into the exceptions the
    test framework expects.

    :param sess: inference session
    :param inputs: dictionary ``{input name: value}``
    :param onx: ONNX model (filename or object), used in messages
    :param cls_session: inference session class, used in messages
    :param verbose: asks the session to log and prints the outputs
    :param intermediate_steps: runs the session again with
        ``verbose=3`` when the prediction fails
    :return: tuple (outputs, function running the same prediction)
    :raises ExpectedAssertionError: the prediction failed and the
        model name contains ``'-Fail'``
    :raises OnnxBackendAssertionError: any other failure
    """
    if verbose:  # pragma no cover
        run_options = {'verbose': 3 if intermediate_steps else 2,
                       'fLOG': print}
    else:
        run_options = {}
    try:
        try:
            output = sess.run(None, inputs, **run_options)
        except TypeError:  # pragma no cover
            # Presumably the session does not accept the logging
            # options, so retry without them.
            output = sess.run(None, inputs)
        if verbose:  # pragma no cover
            import pprint
            pprint.pprint(output)
    except ExpectedAssertionError:  # pragma no cover
        raise
    except (RuntimeError, OrtInvalidArgument) as e:  # pragma no cover
        if intermediate_steps:
            sess.run(None, inputs, verbose=3, fLOG=print)
        # A model whose filename contains '-Fail' is expected to fail.
        # The isinstance check keeps a loaded model from raising TypeError here.
        if isinstance(onx, str) and "-Fail" in onx:
            raise ExpectedAssertionError(
                "{1} cannot compute the prediction for '{0}'".
                format(onx, cls_session)) from e
        smodel = _model_dump_for_error(onx, verbose)
        import pprint
        raise OnnxBackendAssertionError(
            "{4} cannot compute the predictions"
            " for '{0}' due to {1}{2}\n{3}"
            .format(onx, e, smodel, pprint.pformat(inputs),
                    cls_session)) from e
    except Exception as e:  # pragma no cover
        raise OnnxBackendAssertionError(
            "Unable to run onnx '{0}' due to {1}".format(onx, e)) from e
    return output, lambda: sess.run(None, inputs)  # noqa


def compare_runtime_session(
        cls_session, test, decimal=5, options=None,
        verbose=False, context=None, comparable_outputs=None,
        intermediate_steps=False, classes=None,
        disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced with module :epkg:`onnxruntime` or :epkg:`mlprodict`.

    :param cls_session: inference session class (like @see cl OnnxInference),
        instantiated with ``cls_session(onx, runtime_options=...)``
    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options; when None, they are obtained with
        ``extract_options`` if *onnx* is a filename (otherwise empty).
        A dictionary given by the caller is copied and never modified.
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param context: specifies custom operators
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation the runtime may do
    :return: tuple (output, lambda function to run the predictions),
        *output* is a shallow copy of the session outputs
    :raises TypeError: *options* is not a dictionary or *cls_session*
        has an unexpected signature
    :raises ExpectedAssertionError: the failure was expected
    :raises OnnxBackendAssertionError: the model cannot be loaded or run,
        or its outputs differ from the expected ones

    The function raises an exception if the comparison fails.
    """
    if context is None:
        context = {}
    load = load_data_and_model(test, **context)
    if verbose:  # pragma no cover
        print("[compare_runtime] test '{}' loaded".format(test['onnx']))

    onx = test['onnx']

    if options is None:
        options = extract_options(onx) if isinstance(onx, str) else {}
    elif isinstance(options, dict):
        # Keys 'DF' and 'SklCol' are popped below. Work on a copy so the
        # caller's dictionary is left untouched.
        options = dict(options)
    else:
        raise TypeError(  # pragma no cover
            "options must be a dictionary.")

    if verbose:  # pragma no cover
        print("[compare_runtime] InferenceSession('{}')".format(onx))

    sess = _session_from_model(
        cls_session, onx, options, disable_optimisation, verbose)

    data = load["data"]
    inputs = _inputs_from_data(sess, data, onx, options.pop('DF', False))
    for name in inputs:
        if isinstance(inputs[name], list):
            inputs[name] = numpy.array(inputs[name])

    options.pop('SklCol', False)  # unused here but in dump_data_and_model

    if verbose:  # pragma no cover
        print("[compare_runtime] type(inputs)={} len={} names={}".format(
            type(data), len(inputs), list(sorted(inputs))))

    output, lambda_onnx = _run_session(
        sess, inputs, onx, cls_session, verbose, intermediate_steps)
    if verbose:  # pragma no cover
        print("[compare_runtime] done type={}".format(type(output)))

    # copy taken before the comparison so the returned outputs are
    # not affected by anything the comparison does to them
    output0 = output.copy()

    if comparable_outputs:
        cmp_exp = [load["expected"][o] for o in comparable_outputs]
        cmp_out = [output[o] for o in comparable_outputs]
    else:
        cmp_exp = load["expected"]
        cmp_out = output

    try:
        _compare_expected(cmp_exp, cmp_out, sess, onx,
                          decimal=decimal, verbose=verbose,
                          classes=classes, **options)
    except ExpectedAssertionError:  # pragma no cover
        raise
    except Exception as e:  # pragma no cover  # pylint: disable=W0703
        smodel = _model_dump_for_error(onx, verbose)
        raise OnnxBackendAssertionError(
            "Model '{}' has discrepencies with cls='{}'.\n{}: {}{}".format(
                onx, sess.__class__.__name__, type(e), e, smodel)) from e

    return output0, lambda_onnx