Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Utilities to test scripts from :epkg:`scikit-learn` documentation.

4""" 

5import os 

6from io import StringIO 

7from contextlib import redirect_stdout, redirect_stderr 

8import pprint 

9import numpy 

10from sklearn.base import BaseEstimator 

11from .verify_code import verify_code 

12 

13 

class MissingVariableError(RuntimeError):
    """
    Raised when a variable the verification needs (a fitted model,
    its training data, or any trained model at all) cannot be found
    among the variables produced by the executed script.
    """

19 

20 

21def _clean_script(content): 

22 """ 

23 Comments out all lines containing ``.show()``. 

24 """ 

25 new_lines = [] 

26 for line in content.split('\n'): 

27 if '.show()' in line or 'sys.exit' in line: 

28 new_lines.append("# " + line) 

29 else: 

30 new_lines.append(line) 

31 return "\n".join(new_lines) 

32 

33 

34def _enumerate_fit_info(fits): 

35 """ 

36 Extracts the name of the fitted models and the data 

37 used to train it. 

38 """ 

39 for fit in fits: 

40 chs = fit['children'] 

41 if len(chs) < 2: 

42 # unable to extract the needed information 

43 continue # pragma: no cover 

44 model = chs[0]['str'] 

45 if model.endswith('.fit'): 

46 model = model[:-4] 

47 args = [ch['str'] for ch in chs[1:]] 

48 yield model, args 

49 

50 

51def _try_onnx(loc, model_name, args_name, **options): 

52 """ 

53 Tries onnx conversion. 

54 

55 @param loc available variables 

56 @param model_name model name among these variables 

57 @param args_name arguments name among these variables 

58 @param options additional options for the conversion 

59 @return onnx model 

60 """ 

61 from ..onnx_conv import to_onnx 

62 if model_name not in loc: 

63 raise MissingVariableError( # pragma: no cover 

64 "Unable to find model '{}' in {}".format( 

65 model_name, ", ".join(sorted(loc)))) 

66 if args_name[0] not in loc: 

67 raise MissingVariableError( # pragma: no cover 

68 "Unable to find data '{}' in {}".format( 

69 args_name[0], ", ".join(sorted(loc)))) 

70 model = loc[model_name] 

71 X = loc[args_name[0]] 

72 dtype = options.get('dtype', numpy.float32) 

73 Xt = X.astype(dtype) 

74 onx = to_onnx(model, Xt, **options) 

75 args = dict(onx=onx, model=model, X=Xt) 

76 return onx, args 

77 

78 

79def verify_script(file_or_name, try_onnx=True, existing_loc=None, 

80 **options): 

81 """ 

82 Checks that models fitted in an example from :epkg:`scikit-learn` 

83 documentation can be converted into :epkg:`ONNX`. 

84 

85 @param file_or_name file or string 

86 @param try_onnx try the onnx conversion 

87 @param existing_loc existing local variables 

88 @param options conversion options 

89 @return list of converted models 

90 """ 

91 if '\n' not in file_or_name and os.path.exists(file_or_name): 

92 filename = file_or_name 

93 with open(file_or_name, 'r', encoding='utf-8') as f: 

94 content = f.read() 

95 else: # pragma: no cover 

96 content = file_or_name 

97 filename = "<string>" 

98 

99 # comments out .show() 

100 content = _clean_script(content) 

101 

102 # look for fit or predict expressions 

103 _, node = verify_code(content, exc=False) 

104 fits = node._fits 

105 models_args = list(_enumerate_fit_info(fits)) 

106 

107 # execution 

108 obj = compile(content, filename, 'exec') 

109 glo = globals().copy() 

110 loc = {} 

111 if existing_loc is not None: 

112 loc.update(existing_loc) 

113 glo.update(existing_loc) 

114 out = StringIO() 

115 err = StringIO() 

116 

117 with redirect_stdout(out): 

118 with redirect_stderr(err): 

119 exec(obj, glo, loc) # pylint: disable=W0122 

120 

121 # filter out values 

122 cls = (BaseEstimator, numpy.ndarray) 

123 loc_fil = {k: v for k, v in loc.items() if isinstance(v, cls)} 

124 glo_fil = {k: v for k, v in glo.items() if k not in {'__builtins__'}} 

125 onx_info = [] 

126 

127 # onnx 

128 if try_onnx: 

129 if len(models_args) == 0: 

130 raise MissingVariableError( # pragma: no cover 

131 "No detected trained model in '{}'\n{}\n--LOCALS--\n{}".format( 

132 filename, content, pprint.pformat(loc_fil))) 

133 for model_args in models_args: 

134 try: 

135 onx, args = _try_onnx(loc_fil, *model_args, **options) 

136 except MissingVariableError as e: # pragma: no cover 

137 raise MissingVariableError("Unable to find variable in '{}'\n{}".format( 

138 filename, pprint.pformat(fits))) from e 

139 loc_fil[model_args[0] + "_onnx"] = onx 

140 onx_info.append(args) 

141 

142 # final results 

143 return dict(locals=loc_fil, globals=glo_fil, 

144 stdout=out.getvalue(), 

145 stderr=err.getvalue(), 

146 onx_info=onx_info)