Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Inspired from skl2onnx, handles two backends. 

4""" 

5from ...onnxrt import OnnxInference 

6from .utils_backend_common_compare import compare_runtime_session 

7 

8 

class MockVariableName:
    """
    Minimal variable description holding only a *name*.

    Used to mimic the objects the onnxruntime-like API of this module
    returns; *shape* and *type* are not available at this level and
    requesting them raises :class:`NotImplementedError`.
    """

    def __init__(self, name):
        self.name = name

    def _unavailable(self, what):
        "Builds the exception reported when attribute *what* is requested."
        return NotImplementedError(
            "No {} for '{}'.".format(what, self.name))

    @property
    def shape(self):
        "Shape of the variable, unknown for a name-only variable."
        raise self._unavailable("shape")  # pragma: no cover

    @property
    def type(self):
        "Type of the variable, unknown for a name-only variable."
        raise self._unavailable("type")  # pragma: no cover

26 

27 

class MockVariableNameShape(MockVariableName):
    """
    Variable description holding a *name* and a *shape*.

    The *type* property is inherited unchanged from
    :class:`MockVariableName`.
    """

    def __init__(self, name, sh):
        # The shape is stored privately and exposed through a
        # read-only property.
        self._shape = sh
        MockVariableName.__init__(self, name)

    @property
    def shape(self):
        "Shape given to the constructor."
        stored = self._shape
        return stored

39 

40 

class MockVariableNameShapeType(MockVariableNameShape):
    """
    Variable description holding a *name*, a *shape* and a *type*.
    """

    def __init__(self, name, sh, stype):
        # The type is stored privately and exposed through a
        # read-only property.
        self._stype = stype
        MockVariableNameShape.__init__(self, name, sh)

    @property
    def type(self):
        "Type given to the constructor."
        stored = self._stype
        return stored

52 

53 

class OnnxInference2(OnnxInference):
    """
    :class:`OnnxInference` exposing the subset of the onnxruntime
    session API used by the comparison helpers:
    ``run(output_names, inputs)``, ``get_inputs()`` and ``get_outputs()``.
    """

    def run(self, name, inputs, *args, **kwargs):  # pylint: disable=W0221
        """
        onnxruntime API: runs the model and selects the requested outputs.

        :param name: ``None`` to get every output (ordered like
            ``self.output_names``), a list or tuple of output names to
            get those outputs in that order, or a single output name to
            get that output alone
        :param inputs: inputs forwarded to :meth:`OnnxInference.run`
        :param args: ignored; accepted so that extra positional
            arguments (presumably onnxruntime's *run_options*) do not fail
        :param kwargs: forwarded to :meth:`OnnxInference.run`
        :return: a list of outputs, or a single output when *name*
            is a single name
        :raises RuntimeError: if a requested output cannot be found
        """
        # res is indexed by output name below (res[n], `n in res`).
        res = OnnxInference.run(self, inputs, **kwargs)
        if name is None:
            return [res[n] for n in self.output_names]
        if isinstance(name, (list, tuple)):
            # onnxruntime passes output names as a list; testing
            # `name in res` on a list would raise TypeError (unhashable).
            missing = [n for n in name if n not in res]
            if missing:
                raise RuntimeError(
                    "Unable to find outputs {}.".format(missing))
            return [res[n] for n in name]
        if name in res:  # pragma: no cover
            return res[name]
        raise RuntimeError(  # pragma: no cover
            "Unable to find output '{}'.".format(name))

    def get_inputs(self):
        """
        onnxruntime API: returns one object per input exposing
        ``name``, ``shape`` and ``type``, built by unpacking each
        entry of ``self.input_names_shapes_types``.
        """
        return [MockVariableNameShapeType(*n)
                for n in self.input_names_shapes_types]

    def get_outputs(self):
        """
        onnxruntime API: returns one object per output exposing
        ``name`` and ``shape``, built by unpacking each entry of
        ``self.output_names_shapes``.
        """
        return [MockVariableNameShape(*n) for n in self.output_names_shapes]

    def run_in_scan(self, inputs, verbose=0, fLOG=None):
        """
        Instance to run in operator scan.

        Bypasses the onnxruntime-like :meth:`run` above and returns
        the result of :meth:`OnnxInference.run` unchanged.
        """
        return OnnxInference.run(self, inputs, verbose=verbose, fLOG=fLOG)

78 

79 

def compare_runtime(test, decimal=5, options=None,
                    verbose=False, context=None, comparable_outputs=None,
                    intermediate_steps=False, classes=None,
                    disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced by :epkg:`mlprodict`. The ONNX output is computed through
    :class:`OnnxInference2`, which mimics the :epkg:`onnxruntime` API.

    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param context: specifies custom operators
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation the runtime may do
    :return: whatever :func:`compare_runtime_session` returns, documented
        as a tuple (output, lambda function to run the predictions)

    An error is raised if the comparison fails; all the work is
    delegated to :func:`compare_runtime_session`.
    """
    return compare_runtime_session(
        OnnxInference2, test, decimal=decimal, options=options,
        verbose=verbose, context=context,
        comparable_outputs=comparable_outputs,
        intermediate_steps=intermediate_steps,
        classes=classes, disable_optimisation=disable_optimisation)