# -*- coding: utf-8 -*-
"""
@file
@brief Implements different methods to split a dataframe.
"""
import hashlib
import pickle
import random
import warnings
from io import StringIO
import pandas


def sklearn_train_test_split(self, path_or_buf=None, export_method="to_csv",
                             names=None, **kwargs):
    """
    Randomly splits a dataframe into smaller pieces.
    The function returns the file names or buffers which received each partition.
    The function relies on :epkg:`sklearn:model_selection:train_test_split`.
    It does not handle the stratified version of it.

    @param      self            @see cl StreamingDataFrame
    @param      path_or_buf     a string, a list of strings or buffers, if it is a
                                string, it must contain ``{}`` like ``partition{}.txt``
    @param      export_method   method used to store the partitions, by default
                                :epkg:`pandas:DataFrame:to_csv`
    @param      names           partition names, by default ``('train', 'test')``
    @param      kwargs          parameters for the export function and
                                :epkg:`sklearn:model_selection:train_test_split`.
    @return                     outputs of the export functions

    The function cannot return two iterators or two
    @see cl StreamingDataFrame because walking through one
    would mean walking through the other. We assume that neither
    split fits in memory and that the same iterator cannot be walked
    through twice, as the random draws would differ.
    The results therefore need to be stored into files or buffers.

    .. warning::
        The method *export_method* must write the data in
        *append* mode and support streaming.
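
    A minimal usage sketch, assuming @see cl StreamingDataFrame can be built
    with ``read_csv`` and this function is called directly; the file names
    below are only placeholders::

        from pandas_streaming.df import StreamingDataFrame

        sdf = StreamingDataFrame.read_csv("data.csv")
        train_name, test_name = sklearn_train_test_split(
            sdf, path_or_buf="partition_{}.txt", test_size=0.2)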

    """
    if kwargs.get("stratify") is not None:
        raise NotImplementedError(
            "No implementation yet for the stratified version.")
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ImportWarning)
        from sklearn.model_selection import train_test_split  # pylint: disable=C0415

    opts = ['test_size', 'train_size',
            'random_state', 'shuffle', 'stratify']
    split_ops = {}
    for o in opts:
        if o in kwargs:
            split_ops[o] = kwargs[o]
            del kwargs[o]

    exportf_ = getattr(pandas.DataFrame, export_method)
    if export_method == 'to_csv' and 'mode' not in kwargs:
        exportf = lambda *a, **kw: exportf_(*a, mode='a', **kw)
    else:
        exportf = exportf_

    if isinstance(path_or_buf, str):
        if "{}" not in path_or_buf:
            raise ValueError(
                "path_or_buf must contain {} to insert the partition name")
        if names is None:
            names = ['train', 'test']
        elif len(names) != len(path_or_buf):
            raise ValueError(  # pragma: no cover
                'names and path_or_buf must have the same length')
        path_or_buf = [path_or_buf.format(n) for n in names]
    elif path_or_buf is None:
        path_or_buf = [None, None]
    else:
        if not isinstance(path_or_buf, list):
            raise TypeError(  # pragma: no cover
                'path_or_buf must be a list or a string')

    bufs = []
    close = []
    for p in path_or_buf:
        if p is None:
            st = StringIO()
            cl = False
        elif isinstance(p, str):
            st = open(  # pylint: disable=R1732
                p, "w", encoding=kwargs.get('encoding'))
            cl = True
        else:
            st = p
            cl = False
        bufs.append(st)
        close.append(cl)

    for df in self:
        train, test = train_test_split(df, **split_ops)
        exportf(train, bufs[0], **kwargs)
        exportf(test, bufs[1], **kwargs)
        kwargs['header'] = False

    for b, c in zip(bufs, close):
        if c:
            b.close()
    return [st.getvalue() if isinstance(st, StringIO) else p
            for st, p in zip(bufs, path_or_buf)]


def sklearn_train_test_split_streaming(self, test_size=0.25, train_size=None,
                                       stratify=None, hash_size=9, unique_rows=False):
    """
    Randomly splits a dataframe into smaller pieces.
    The function returns two @see cl StreamingDataFrame and writes nothing on disk.
    It is a streaming counterpart of :epkg:`sklearn:model_selection:train_test_split`
    and handles the stratified version of it.

    @param      self            @see cl StreamingDataFrame
    @param      test_size       ratio for the test partition (if *train_size* is not specified)
    @param      train_size      ratio for the train partition
    @param      stratify        column holding the stratification
    @param      hash_size       size of the hash used to cache the partition each row belongs to
    @param      unique_rows     ensures that rows are unique
    @return                     Two @see cl StreamingDataFrame, one
                                for train, one for test.

    The function returns two iterators or two
    @see cl StreamingDataFrame. It
    tries to do everything without writing anything on disk,
    but it requires storing the repartition somehow.
    The function hashes every row and maps the hash to a partition
    (train or test). This cache must hold in memory, otherwise the
    function fails. The two returned iterators must not be walked
    through for the first time simultaneously: the first pass is used
    to build the cache. The function changes the order of rows if
    the parameter *stratify* is not None. The cache has a side effect:
    every identical row ends up in the same partition.
    If that is not what you want, you should add an index column
    or a random one.
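
    A minimal usage sketch, assuming @see cl StreamingDataFrame is importable
    as shown and that the data has a column ``label`` to stratify on (both are
    illustrative assumptions)::

        from pandas_streaming.df import StreamingDataFrame

        sdf = StreamingDataFrame.read_csv("data.csv")
        train, test = sklearn_train_test_split_streaming(
            sdf, test_size=0.2, stratify="label")
        # the first pass builds the hash cache ...
        n_train = sum(df.shape[0] for df in train)
        # ... which the second stream reuses
        n_test = sum(df.shape[0] for df in test)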

    """
    p = (1 - test_size) if test_size else None
    if train_size is not None:
        p = train_size
    # minimal number of rows to buffer for each stratification value
    # before splitting a batch (see iterator_rows below)
    n = 2 * max(1 / p, 1 / (1 - p))  # change

    static_schema = []

    def iterator_rows():
        "iterates on rows"
        counts = {}
        memory = {}
        pos_col = None
        for df in self:
            if pos_col is None:
                static_schema.append(list(df.columns))
                static_schema.append(list(df.dtypes))
                static_schema.append(df.shape[0])
                if stratify is not None:
                    pos_col = list(df.columns).index(stratify)
                else:
                    pos_col = -1

            for obs in df.itertuples(index=False, name=None):
                strat = 0 if stratify is None else obs[pos_col]
                if strat not in memory:
                    memory[strat] = []
                memory[strat].append(obs)

            for k, v in memory.items():
                if len(v) >= n + random.randint(0, 10):  # change
                    vr = list(range(len(v)))
                    # random permutation of the buffered rows
                    random.shuffle(vr)
                    if (0, k) in counts:
                        tt = counts[1, k] + counts[0, k]
                        delta = - int(counts[0, k] - tt * p + 0.5)
                    else:
                        delta = 0
                    i = int(len(v) * p + 0.5)
                    i += delta
                    i = max(0, min(len(v), i))
                    one = set(vr[:i])
                    for d, obs_ in enumerate(v):
                        yield obs_, 0 if d in one else 1
                    if (0, k) not in counts:
                        counts[0, k] = i
                        counts[1, k] = len(v) - i
                    else:
                        counts[0, k] += i
                        counts[1, k] += len(v) - i
                    # remove the yielded rows from memory
                    v.clear()

        # Once the stream is exhausted, the observations still stored
        # in memory must be distributed as well.
        for k, v in memory.items():
            vr = list(range(len(v)))
            # random permutation of the buffered rows
            random.shuffle(vr)
            if (0, k) in counts:
                tt = counts[1, k] + counts[0, k]
                delta = - int(counts[0, k] - tt * p + 0.5)
            else:
                delta = 0
            i = int(len(v) * p + 0.5)
            i += delta
            i = max(0, min(len(v), i))
            one = set(vr[:i])
            for d, obs in enumerate(v):
                yield obs, 0 if d in one else 1
            if (0, k) not in counts:
                counts[0, k] = i
                counts[1, k] = len(v) - i
            else:
                counts[0, k] += i
                counts[1, k] += len(v) - i

    def h11(w):
        "pickle and hash"
        b = pickle.dumps(w)
        return hashlib.md5(b).hexdigest()[:hash_size]

    # We store the repartition in a cache.
    cache = {}
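    # The cache maps the hash of a row to the partition it was assigned to
    # (0 for train, 1 for test): the first pass through iterator_rows fills
    # it, the following passes simply replay the same assignment.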

    def iterator_internal(part_requested):
        "internal iterator on dataframes"
        iy = 0
        accumul = []
        if len(cache) == 0:
            for obs, part in iterator_rows():
                h = h11(obs)
                if unique_rows and h in cache:
                    raise ValueError(  # pragma: no cover
                        "A row or at least its hash is already cached. "
                        "Increase hash_size or check for duplicates "
                        "('{0}')\n{1}.".format(h, obs))
                if h not in cache:
                    cache[h] = part
                else:
                    part = cache[h]
                if part == part_requested:
                    accumul.append(obs)
                    if len(accumul) >= static_schema[2]:
                        dfo = pandas.DataFrame(
                            accumul, columns=static_schema[0])
                        self.ensure_dtype(dfo, static_schema[1])
                        iy += dfo.shape[0]
                        accumul.clear()
                        yield dfo
        else:
            for df in self:
                for obs in df.itertuples(index=False, name=None):
                    h = h11(obs)
                    part = cache.get(h)
                    if part is None:
                        raise ValueError(  # pragma: no cover
                            "Second iteration. A row was never met in the first one\n{0}".format(obs))
                    if part == part_requested:
                        accumul.append(obs)
                        if len(accumul) >= static_schema[2]:
                            dfo = pandas.DataFrame(
                                accumul, columns=static_schema[0])
                            self.ensure_dtype(dfo, static_schema[1])
                            iy += dfo.shape[0]
                            accumul.clear()
                            yield dfo
        if len(accumul) > 0:
            dfo = pandas.DataFrame(accumul, columns=static_schema[0])
            self.ensure_dtype(dfo, static_schema[1])
            iy += dfo.shape[0]
            yield dfo

    return (self.__class__(lambda: iterator_internal(0)),
            self.__class__(lambda: iterator_internal(1)))