Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Complex but recurring testing functions. 

4""" 

5import random 

6import pandas 

7import numpy 

8from numpy.testing import assert_allclose 

9from ..grammar_sklearn import sklearn2graph 

10from ..grammar_sklearn.cc import compile_c_function 

11 

12 

def iris_data():
    """
    Builds a small, noisy two-feature classification dataset from iris.

    Only the first two iris features are kept. Standard normal noise,
    scaled by 1/3 and drawn from a ``RandomState`` seeded with 34, is added
    to them, so two calls return the same values.

    @return     tuple ``(X, y)``: noisy features and iris targets
    """
    from sklearn.datasets import load_iris
    dataset = load_iris()
    features = dataset.data[:, :2]
    generator = numpy.random.RandomState(seed=34)  # pylint: disable=E1101
    features += generator.randn(*features.shape) / 3
    return features, dataset.target

25 

26 

def check_is_almost_equal(xv, exp, precision=1e-5, message=None):
    """
    Checks that two floats or two arrays are almost equal.

    @param      xv          computed value, float or vector
    @param      exp         expected value, float or vector
    @param      precision   absolute tolerance used for the comparison
    @param      message     additional message, prepended to the error
                            raised when values differ
    @raise      TypeError       if *xv* and *exp* are not comparable
                                (single value vs array, or *xv* is not
                                a numpy array when *exp* is an array)
    @raise      ValueError      if two single values differ by more than
                                *precision*
    @raise      AssertionError  if two arrays are not close
                                (see ``numpy.testing.assert_allclose``)
    """
    if isinstance(exp, (int, float)) or len(exp.ravel()) == 1:
        if not (isinstance(xv, (int, float)) or len(xv.ravel()) == 1):
            raise TypeError(  # pragma: no cover
                "Type mismatch between {0} and {1} (expected).".format(
                    type(xv), type(exp)))
        diff = abs(xv - exp)
        # Two NaN are considered equal, like assert_allclose does in the
        # array branch. Any other NaN makes ``diff <= precision`` False,
        # so a NaN prediction can no longer pass silently.
        both_nan = bool(numpy.all(numpy.isnan(xv)) and
                        numpy.all(numpy.isnan(exp)))
        if not both_nan and not diff <= precision:
            msg = "Predictions are different expected={0}, computed={1}".format(
                exp, xv)
            if message is not None:
                msg = "{0}\n{1}".format(message, msg)
            raise ValueError(msg)  # pragma: no cover
        return

    if not isinstance(xv, numpy.ndarray):
        raise TypeError(
            "Type mismatch between {0} and {1} (expected).".format(type(xv), type(exp)))
    try:
        assert_allclose(xv.ravel(), exp.ravel(), atol=precision)
    except AssertionError as e:
        if message is None:
            raise
        # Keep numpy's detailed report after the user message.
        raise AssertionError("{0}\n{1}".format(message, e)) from e  # pragma: no cover

59 

60 

61def check_model_representation(model, X, y=None, convs=None, 

62 output_names=None, only_float=True, 

63 verbose=False, suffix="", fLOG=None): 

64 """ 

65 Checks that a trained model can be exported in a specific list 

66 of formats and produces the same outputs if the 

67 representation can be used to predict. 

68 

69 @param model model (a class or an instance of a model but not trained) 

70 @param X features 

71 @param y targets 

72 @param convs list of format to check, all possible by default ``['json', 'c']`` 

73 @param output_names list of output columns 

74 (can be None, a default value is infered based on scikit-learn output then) 

75 @param verbose print some information 

76 @param suffix add this to disambiguate module 

77 @param fLOG logging function 

78 @return function to call to run the prediction 

79 """ 

80 if not only_float: 

81 raise NotImplementedError( # pragma: no cover 

82 "Only float are allowed.") 

83 if isinstance(X, list): 

84 X = pandas.DataFrame(X) 

85 if len(X.shape) != 2: 

86 raise ValueError( # pragma: no cover 

87 "X cannot be converted into a proper DataFrame. It has shape {0}." 

88 "".format(X.shape)) 

89 if only_float: 

90 X = X.values 

91 if isinstance(y, list): 

92 y = numpy.array(y) 

93 if convs is None: 

94 convs = ['json', 'c'] 

95 

96 # sklearn 

97 if not hasattr(model.__class__, "fit"): 

98 # It is a class object and not an instance. 

99 # We use the default values. 

100 model = model() 

101 

102 model.fit(X, y) 

103 h = random.randint(0, X.shape[0] - 1) 

104 if isinstance(X, pandas.DataFrame): 

105 oneX = X.iloc[h, :].astype(numpy.float32) 

106 else: 

107 oneX = X[h, :].ravel().astype(numpy.float32) 

108 

109 # model or transform 

110 moneX = numpy.resize(oneX, (1, len(oneX))) 

111 if hasattr(model, "predict"): 

112 ske = model.predict(moneX) 

113 else: 

114 ske = model.transform(moneX) 

115 

116 if verbose and fLOG: 

117 fLOG("---------------------") 

118 fLOG(type(oneX), oneX.dtype) 

119 fLOG(model) 

120 for k, v in sorted(model.__dict__.items()): 

121 if k[-1] == '_': 

122 fLOG(" {0}={1}".format(k, v)) 

123 fLOG("---------------------") 

124 

125 # grammar 

126 gr = sklearn2graph(model, output_names=output_names) 

127 lot = gr.execute(Features=oneX) 

128 if verbose and fLOG: 

129 fLOG(gr.graph_execution()) 

130 

131 # verification 

132 check_is_almost_equal(lot, ske) 

133 

134 # default for output_names 

135 if output_names is None: 

136 if len(ske.shape) == 1: 

137 output_names = ["Prediction"] 

138 elif len(ske.shape) == 2: 

139 output_names = ["p%d" % i for i in range(ske.shape[1])] 

140 else: 

141 raise ValueError( # pragma: no cover 

142 "Cannot guess default values for output_names.") 

143 

144 for lang in convs: 

145 if lang in ('c', ): 

146 code_c = gr.export(lang=lang)['code'] 

147 if code_c is None: 

148 raise ValueError("cannot be None") # pragma: no cover 

149 

150 compile_fct = compile_c_function 

151 

152 from contextlib import redirect_stdout, redirect_stderr 

153 from io import StringIO 

154 fout = StringIO() 

155 ferr = StringIO() 

156 with redirect_stdout(fout): 

157 with redirect_stderr(ferr): 

158 try: 

159 fct = compile_fct( 

160 code_c, len(output_names), suffix=suffix, fLOG=lambda s: fout.write(s + "\n")) 

161 except Exception as e: # pragma: no cover 

162 raise RuntimeError( 

163 "Unable to compile a code\n-OUT-\n{0}\n-ERR-\n{1}\n-CODE-" 

164 "\n{2}".format(fout.getvalue(), ferr.getvalue(), code_c)) from e 

165 

166 if verbose and fLOG: 

167 fLOG("-----------------") 

168 fLOG(output_names) 

169 fLOG("-----------------") 

170 fLOG(code_c) 

171 fLOG("-----------------") 

172 fLOG("h=", h, "oneX=", oneX) 

173 fLOG("-----------------") 

174 lotc = fct(oneX) 

175 check_is_almost_equal( 

176 lotc, ske, message="Issue with lang='{0}'".format(lang)) 

177 lotc_exp = lotc.copy() 

178 lotc2 = fct(oneX, lotc) 

179 if not numpy.array_equal(lotc_exp, lotc2): 

180 raise ValueError( # pragma: no cover 

181 "Second call returns different results.\n{0}\n{1}".format( 

182 lotc_exp, lotc2)) 

183 else: 

184 ser = gr.export(lang="json", hook={'array': lambda v: v.tolist()}) 

185 if ser is None: 

186 raise ValueError( # pragma: no cover 

187 "No output for long='{0}'".format(lang))