Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Validates runtime for many :epkg:`scikit-learn` operators. 

4The submodule relies on :epkg:`onnxconverter_common` and 

5:epkg:`sklearn-onnx`. 

6""" 

7import numpy 

8from skl2onnx.common.data_types import ( 

9 FloatTensorType, DoubleTensorType) 

10 

11 

# Pairs of (English spelling, integer value) for the numbers 0 to 29.
# Every label is English; entry 10 was previously the French 'dix'.
text_alpha_num = [
    ('zero', 0),
    ('one', 1),
    ('two', 2),
    ('three', 3),
    ('four', 4),
    ('five', 5),
    ('six', 6),
    ('seven', 7),
    ('eight', 8),
    ('nine', 9),
    ('ten', 10),
    ('eleven', 11),
    ('twelve', 12),
    ('thirteen', 13),
    ('fourteen', 14),
    ('fifteen', 15),
    ('sixteen', 16),
    ('seventeen', 17),
    ('eighteen', 18),
    ('nineteen', 19),
    ('twenty', 20),
    ('twenty one', 21),
    ('twenty two', 22),
    ('twenty three', 23),
    ('twenty four', 24),
    ('twenty five', 25),
    ('twenty six', 26),
    ('twenty seven', 27),
    ('twenty eight', 28),
    ('twenty nine', 29),
]

44 

45 

46def _guess_noshape(obj, shape): 

47 if isinstance(obj, numpy.ndarray): 

48 if obj.dtype == numpy.float32: 

49 return FloatTensorType(shape) # pragma: no cover 

50 if obj.dtype == numpy.float64: 

51 return DoubleTensorType(shape) 

52 raise NotImplementedError( # pragma: no cover 

53 "Unable to process object(1) [{}].".format(obj)) 

54 raise NotImplementedError( # pragma: no cover 

55 "Unable to process object(2) [{}].".format(obj)) 

56 

57 

58def _noshapevar(fct): 

59 

60 def process_itt(itt, Xort): 

61 if isinstance(itt, tuple): 

62 return (process_itt(itt[0], Xort), itt[1]) 

63 

64 # name = "V%s_" % str(id(Xort))[:5] 

65 new_itt = [] 

66 for a, b in itt: 

67 # shape = [name + str(i) for s in b.shape] 

68 shape = [None for s in b.shape] 

69 new_itt.append((a, _guess_noshape(b, shape))) 

70 return new_itt 

71 

72 def new_fct(**kwargs): 

73 X, y, itt, meth, mo, Xort = fct(**kwargs) 

74 new_itt = process_itt(itt, Xort) 

75 return X, y, new_itt, meth, mo, Xort 

76 return new_fct 

77 

78 

79def _1d_problem(fct): 

80 

81 def new_fct(**kwargs): 

82 n_features = kwargs.get('n_features', None) 

83 if n_features not in (None, 1): 

84 raise RuntimeError( # pragma: no cover 

85 "Misconfiguration: the number of features must not be " 

86 "specified for a 1D problem.") 

87 X, y, itt, meth, mo, Xort = fct(**kwargs) 

88 new_itt = itt # process_itt(itt, Xort) 

89 X = X[:, 0] 

90 return X, y, new_itt, meth, mo, Xort 

91 return new_fct