Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Inspired by skl2onnx; handles two backends. 

4""" 

5import numpy 

6from ...tools.asv_options_helper import get_opset_number_from_onnx 

7from .utils_backend_onnxruntime import _capture_output 

8 

9 

10from .tests_helper import ( # noqa 

11 binary_array_to_string, 

12 dump_data_and_model, 

13 dump_one_class_classification, 

14 dump_binary_classification, 

15 dump_multilabel_classification, 

16 dump_multiple_classification, 

17 dump_multiple_regression, 

18 dump_single_regression, 

19 convert_model, 

20 fit_classification_model, 

21 fit_classification_model_simple, 

22 fit_multilabel_classification_model, 

23 fit_regression_model) 

24 

25 

def create_tensor(N, C, H=None, W=None):
    """
    Creates a random tensor of type float32.

    Values are drawn with :func:`numpy.random.rand`, i.e. uniformly
    over ``[0, 1)``, then cast to ``numpy.float32``.

    :param N: size of the first dimension
    :param C: size of the second dimension
    :param H: size of the third dimension, None to build a 2-D tensor
    :param W: size of the fourth dimension, None to build a 2-D tensor
    :return: array of shape ``(N, C)`` when *H* and *W* are both None,
        ``(N, C, H, W)`` when both are specified
    :raises ValueError: if only one of *H* and *W* is specified
    """
    if H is None and W is None:
        shape = (N, C)
    elif H is not None and W is not None:
        shape = (N, C, H, W)
    else:
        # Only one spatial dimension was given: there is no 3-D variant.
        raise ValueError(  # pragma no cover
            "This function only produces 2-D or 4-D tensors: H and W must "
            "be both None or both specified (H=%r, W=%r)." % (H, W))
    return numpy.random.rand(*shape).astype(  # pylint: disable=E1101
        numpy.float32, copy=False)

34 

35 

36def _get_ir_version(opv): 

37 if opv >= 12: 

38 return 7 

39 if opv >= 11: # pragma no cover 

40 return 6 

41 if opv >= 10: # pragma no cover 

42 return 5 

43 if opv >= 9: # pragma no cover 

44 return 4 

45 if opv >= 8: # pragma no cover 

46 return 4 

47 return 3 # pragma no cover 

48 

49 

# Opset number returned by get_opset_number_from_onnx (presumably the
# highest opset supported by the installed onnx package; confirm),
# followed by the IR version _get_ir_version associates with it.
TARGET_OPSET = get_opset_number_from_onnx()
TARGET_IR = _get_ir_version(opv=TARGET_OPSET)