Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

@brief Inspired by :epkg:`sklearn-onnx`; handles two backends: onnxruntime and the Python runtime.

4""" 

5from .utils_backend_onnxruntime import compare_runtime as compare_runtime_ort 

6from .utils_backend_python import compare_runtime as compare_runtime_pyrt 

7 

8 

def compare_backend(backend, test, decimal=5, options=None, verbose=False,
                    context=None, comparable_outputs=None,
                    intermediate_steps=False, classes=None,
                    disable_optimisation=False):
    """
    Compares the expected output (computed with the model before it was
    converted to ONNX) with the output of the ONNX model run on the
    requested backend.

    :param backend: backend used to run the comparison, either
        ``'onnxruntime'`` or ``'python'`` (exact, case-sensitive match)
    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param context: specifies custom operators; NOTE(review): this
        argument is currently accepted but NOT forwarded to either
        backend, it only exists for signature compatibility
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps in case of
        an error; only forwarded to the ``'python'`` backend, the
        ``'onnxruntime'`` backend always receives ``False``
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable the optimisations onnxruntime
        could do
    :return: whatever the selected backend's ``compare_runtime`` returns,
        presumably a tuple (output, lambda function to call onnx
        predictions) — TODO confirm against the backend modules; the
        backend is expected to raise an error if the comparison fails
    :raises ValueError: if *backend* is not one of the supported names
    """
    # Keyword arguments forwarded identically to both runtimes.
    shared = dict(options=options, verbose=verbose,
                  comparable_outputs=comparable_outputs,
                  classes=classes,
                  disable_optimisation=disable_optimisation)
    if backend == "onnxruntime":
        # intermediate_steps is pinned to False for onnxruntime, exactly
        # as in the original implementation; presumably onnxruntime cannot
        # expose intermediate results — confirm before forwarding it.
        return compare_runtime_ort(
            test, decimal, intermediate_steps=False, **shared)
    if backend == "python":
        return compare_runtime_pyrt(
            test, decimal, intermediate_steps=intermediate_steps, **shared)
    raise ValueError(  # pragma: no cover
        "Does not support backend '{0}'.".format(backend))