Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Inspired from skl2onnx, handles two backends. 

4""" 

5from pyquickhelper.pycode import is_travis_or_appveyor 

6from .utils_backend_common_compare import compare_runtime_session 

7from ...tools.ort_wrapper import ( 

8 InferenceSession, GraphOptimizationLevel, SessionOptions) 

9 

10 

def _capture_output(fct, kind):
    """
    Calls *fct* and captures what it writes on the standard output and
    error when the capture is possible.

    The capture is skipped on Travis/AppVeyor and when
    :epkg:`cpyquickhelper` cannot be imported.

    :param fct: function without arguments to call
    :param kind: forwarded unchanged to ``cpyquickhelper.io.capture_output``
    :return: ``capture_output(fct, kind)`` when the capture happens,
        ``(fct(), None, None)`` otherwise
    """
    if not is_travis_or_appveyor():
        try:
            from cpyquickhelper.io import capture_output
        except ImportError:
            # cpyquickhelper is optional, fall back to a plain call below
            capture_output = None
        if capture_output is not None:
            return capture_output(fct, kind)  # pragma: no cover
    return fct(), None, None  # pragma: no cover

20 

21 

class InferenceSession2:
    """
    Overwrites class *InferenceSession* to capture
    the standard output and error.

    Output and error captured while building the session are stored in
    *outi* / *erri*. Those captured during the last call to *run* are
    stored in *outr* / *errr*. Each of them may be None when the capture
    is not available.
    """

    def __init__(self, *args, **kwargs):
        """
        Overwrites the constructor.

        Every positional and keyword argument is forwarded to
        *InferenceSession* except *runtime_options*. This is an optional
        dictionary from which only the key ``'disable_optimisation'`` is
        read. When that value is true, a *SessionOptions* with all graph
        optimisations disabled is created and given as *sess_options*.

        :raises RuntimeError: if ``'disable_optimisation'`` is requested
            while *sess_options* is also given as a keyword argument
        """
        # The caller's dictionary is only read, never modified, so the same
        # options can be reused to build several sessions. None means
        # "no option".
        runtime_options = kwargs.pop('runtime_options', None) or {}
        disable_optimisation = runtime_options.get(
            'disable_optimisation', False)
        if disable_optimisation:
            # NOTE(review): only the keyword form of sess_options is detected.
            if 'sess_options' in kwargs:
                raise RuntimeError(  # pragma: no cover
                    "Incompatible options, 'disable_optimisation' and "
                    "'sess_options' cannot be specified at the same time.")
            kwargs['sess_options'] = SessionOptions()
            kwargs['sess_options'].graph_optimization_level = (
                GraphOptimizationLevel.ORT_DISABLE_ALL)
        # Defined up front so both attributes exist before the first run.
        self.outr, self.errr = None, None
        self.sess, self.outi, self.erri = _capture_output(
            lambda: InferenceSession(*args, **kwargs), 'c')

    def run(self, *args, **kwargs):
        """
        Overwrites method *run*.

        Stores captured output/error in *outr* / *errr* and returns
        the result of the wrapped session's *run*.
        """
        res, self.outr, self.errr = _capture_output(
            lambda: self.sess.run(*args, **kwargs), 'c')
        return res

    def get_inputs(self, *args, **kwargs):
        "Overwrites method *get_inputs* (delegates to the wrapped session)."
        return self.sess.get_inputs(*args, **kwargs)

    def get_outputs(self, *args, **kwargs):
        "Overwrites method *get_outputs* (delegates to the wrapped session)."
        return self.sess.get_outputs(*args, **kwargs)

57 

58 

def compare_runtime(test, decimal=5, options=None,
                    verbose=False, context=None, comparable_outputs=None,
                    intermediate_steps=False, classes=None,
                    disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced with module :epkg:`onnxruntime` or :epkg:`mlprodict`.

    The whole comparison is delegated to *compare_runtime_session*,
    with *InferenceSession2* (which captures standard output and error)
    as the runtime class.

    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param context: specifies custom operators
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable the graph optimisations
        onnxruntime could do
    :return: the value returned by *compare_runtime_session*, presumably
        a tuple ``(output, function running the predictions)``
        (TODO confirm against *compare_runtime_session*)

    *compare_runtime_session* is expected to raise an exception
    when the comparison fails.
    """
    return compare_runtime_session(
        InferenceSession2, test, decimal=decimal, options=options,
        verbose=verbose, context=context,
        comparable_outputs=comparable_outputs,
        intermediate_steps=intermediate_steps,
        classes=classes, disable_optimisation=disable_optimisation)