Coverage for src/mlstatpy/nlp/completion.py: 93%

362 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-27 05:59 +0100

1""" 

2@file 

3@brief About completion 

4""" 

5from typing import Tuple, List, Iterator 

6from collections import deque 

7 

8 

class CompletionTrieNode:
    """
    Node definition in a trie used to do completion, see :ref:`l-completion0`.
    This implementation is not very efficient about memory consumption,
    it does not hold above 200.000 words.
    It should be done another way (:epkg:`cython`, :epkg:`C++`).
    """

    __slots__ = ("value", "children", "weight",
                 "leave", "stat", "parent", "disp")

    def __init__(self, value, leave, weight=1.0, disp=None):
        """
        @param      value       value (the full prefix held by the node)
        @param      leave       boolean (is it a completion)
        @param      weight      ordering (the lower, the first)
        @param      disp        original string, use this to identify the node
        """
        if not isinstance(value, str):
            raise TypeError(
                f"value must be str not '{value}' - type={type(value)}")
        self.value = value
        self.children = None  # dict {character: child}, created lazily
        self.weight = weight
        self.leave = leave
        self.stat = None  # filled by precompute_stat
        self.parent = None
        self.disp = disp

    @property
    def root(self):
        """
        Returns the initial node with no parent.
        """
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    def __str__(self):
        """
        usual
        """
        return f"[{'#' if self.leave else '-'}:{self.value}:w={self.weight}]"

    def _add(self, key, child):
        """
        Adds a child.

        @param      key         one letter of the word
        @param      child       child
        @return                 self

        Raises KeyError if *key* is already registered.
        """
        if self.children is None:
            self.children = {}
        elif key in self.children:
            raise KeyError(f"'{key}' already added")
        self.children[key] = child
        child.parent = self
        return self

    def items_list(self) -> List['CompletionTrieNode']:
        """
        All children nodes including itself in a list
        (same order as the sorted iteration).

        @return         list of nodes
        """
        # __iter__ walks the tree in the same sorted pre-order the recursive
        # version produced, without recursing once per character.
        return list(self)

    def __iter__(self):
        """
        Iterates on all nodes (sorted).
        """
        stack = [self]
        while stack:
            node = stack.pop()
            yield node
            if node.children:
                # pushed in reverse order so the smallest key is popped first
                stack.extend(node.children[k]
                             for k in sorted(node.children, reverse=True))

    def unsorted_iter(self):
        """
        Iterates on all nodes.
        """
        stack = [self]
        while stack:
            node = stack.pop()
            yield node
            if node.children:
                stack.extend(node.children.values())

    def items(self) -> Iterator[Tuple[float, str, 'CompletionTrieNode']]:
        """
        Iterates on children, iterates on weight, key, child.
        """
        if self.children is not None:
            for key, child in self.children.items():
                yield child.weight, key, child

    def iter_leaves(self, max_weight=None) -> Iterator[Tuple[float, str]]:
        """
        Iterators on leaves sorted per weight, yield weight, value.

        @param      max_weight  keep all value under this threshold or None for all
        """
        # Values are unique in a trie, (weight, value) is then a total order
        # and gives the same sequence as the former recursive implementation.
        yield from sorted(
            (node.weight, node.value) for node in self.leaves()
            if max_weight is None or node.weight <= max_weight)

    def leaves(self) -> Iterator['CompletionTrieNode']:
        """
        Iterators on leaves.
        """
        for node in self.unsorted_iter():
            if node.leave:
                yield node

    def all_completions(self) -> List[Tuple['CompletionTrieNode', List[str]]]:
        """
        Retrieves all completions for a node,
        the method does not need @see me precompute_stat to be run first.

        @return     list of ``(node, completions)`` for every prefix
                    from the root down to this node
        """
        path = [self.root]
        node = path[0]
        for c in self.value:
            if node.children is not None and c in node.children:
                node = node.children[c]
                path.append(node)
        return [(n, [value for _, value in n.iter_leaves()]) for n in path]

    def all_mks_completions(self) -> List[Tuple['CompletionTrieNode',
                                                List[Tuple[float, 'CompletionTrieNode']]]]:
        """
        Retrieves all completions for a node,
        the method assumes @see me precompute_stat was run.

        @return     list of ``(node, [(weight, completion node)])`` for every
                    prefix from the root down to this node
        """
        res = []
        node = self
        while True:
            res.append((node, node.stat.completions))
            if node.parent is None:
                break
            node = node.parent
        res.reverse()
        return res

    def str_all_completions(self, maxn=10, use_precompute=True) -> str:
        """
        Builds a string with all completions for all
        prefixes along the paths.

        @param      maxn            maximum number of completions to show
                                    (None to show all of them)
        @param      use_precompute  use intermediate results built by @see me precompute_stat
        @return                     str
        """
        res = self.all_mks_completions() if use_precompute else self.all_completions()
        rows = []
        for node, sug in res:
            # statistics may be missing when use_precompute is False
            stat = node.stat.str_mks() if node.stat is not None else "-"
            rows.append("l={3} p='{0}' {1} {2}".format(
                node.value, "-" * 10, stat, '+' if node.leave else '-'))
            for i, s in enumerate(sug):
                if maxn is not None and i >= maxn:
                    # exactly maxn completions were displayed
                    break
                if isinstance(s, str):
                    rows.append(f" {i + 1}-'{s}'")
                else:
                    rows.append(
                        f" {i + 1}-w{s[0]}-'{s[1].value}'")
        return "\n".join(rows)

    @staticmethod
    def build(words) -> 'CompletionTrieNode':
        """
        Builds a trie.

        @param  words       list of ``(word)`` or ``(weight, word)`` or ``(weight, word, display string)``
        @return             root of the trie (CompletionTrieNode)

        A weight equal to None is replaced by the position of the word
        in the list. Raises ValueError if a word appears twice.
        """
        root = CompletionTrieNode('', False)
        minw = None
        for nb, wword in enumerate(words):
            if isinstance(wword, tuple):
                if len(wword) == 2:
                    w, word = wword
                    disp = None
                elif len(wword) == 3:
                    w, word, disp = wword
                else:
                    raise ValueError(
                        f"Unexpected number of values, it should be (weight, word) or (weight, word, dispplay string): {wword}")
            else:
                w, word, disp = 1.0, wword, None
            if w is None:
                w = nb
            if minw is None or minw > w:
                minw = w

            node = root
            created = False
            for c in word:
                children = node.children
                if children is not None and c in children:
                    # an intermediate prefix takes the weight of its best completion
                    if not node.leave:
                        node.weight = min(node.weight, w)
                    node = children[c]
                else:
                    child = CompletionTrieNode(node.value + c, False, weight=w)
                    node._add(c, child)
                    node = child
                    created = True

            if not created and node.leave:
                raise ValueError(
                    f"Value '{word}' appears twice in the input list (not allowed).")
            node.leave = True
            node.weight = w
            if disp is not None:
                node.disp = disp
        root.weight = minw
        return root

    def find(self, prefix: str) -> 'CompletionTrieNode':
        """
        Returns the node which holds all completions starting with a given prefix.

        @param      prefix      prefix
        @return                 node or None for no result

        Raises ValueError if the prefix is empty and the node is not the root.
        """
        if len(prefix) == 0:
            if self.value:
                raise ValueError(
                    f"find '{prefix}' but node is not empty '{self.value}'")
            return self
        node = self
        for c in prefix:
            children = node.children
            if children is None or c not in children:
                return None
            node = children[c]
        return node

    def min_keystroke(self, word: str) -> Tuple[int, int]:
        """
        Returns the minimum keystrokes for a word without optimisation,
        this function should be used if you only have a couple of values to
        computes. You should use @see me min_keystroke0 to compute all of them.

        @param      word        word
        @return                 number, length of best prefix
                                (``len(word), -1`` if the word is not a completion)

        See :ref:`l-completion-optim`.

        .. math::
            :nowrap:

            \\begin{eqnarray*}
            K(q, k, S) &=& \\min\\acc{ i | s_i \\succ q[1..k], s_i \\in S } \\\\
            M(q, S) &=& \\min_{0 \\infegal k \\infegal l(q)}  k + K(q, k, S)
            \\end{eqnarray*}
        """
        path = [self]
        node = self
        for c in word:
            children = node.children
            if children is None or c not in children:
                # not found
                return len(word), -1
            node = children[c]
            path.append(node)
        if word and not node.leave:
            # the word is only a prefix of some completions, it never shows up
            # in any list of completions (list.index used to fail here)
            return len(word), -1

        metric = best = len(word)
        # every strict prefix, from the longest one to the root
        for node in reversed(path[:-1]):
            completions = [value for _, value in node.iter_leaves()]
            ind = completions.index(word)
            m = len(node.value) + ind + 1
            if m < metric:
                metric = m
                best = len(node.value)
            if ind >= len(word):
                # no need to go further, the position will increase
                break
        return metric, best

    def _checked_stat(self, word, attribute):
        """
        Returns the statistics of the node holding *word* after checking
        they were computed.

        @param      word        word
        @param      attribute   attribute the statistics must have
        @return                 instance of *_Stat*
        """
        node = self.find(word)
        if node is None:
            raise NotImplementedError(
                f"this metric is not yet computed for a query outside the trie: '{word}'")
        # attribute *stat* always exists (slot set to None in the constructor)
        if node.stat is None:
            raise AttributeError("run precompute_stat and update_stat_dynamic")
        if not hasattr(node.stat, attribute):
            raise AttributeError("run precompute_stat and update_stat_dynamic\nnode={0}\n{1}".format(
                node, "\n".join(sorted(node.stat.__dict__))))
        return node.stat

    def min_keystroke0(self, word: str) -> Tuple[int, int, int]:
        """
        Returns the minimum keystrokes for a word.

        @param      word        word
        @return                 number, length of best prefix, iteration it stops moving

        This function must be called after @see me precompute_stat
        and @see me update_stat_dynamic.

        See :ref:`l-completion-optim`.

        .. math::
            :nowrap:

            \\begin{eqnarray*}
            K(q, k, S) &=& \\min\\acc{ i | s_i \\succ q[1..k], s_i \\in S } \\\\
            M(q, S) &=& \\min_{0 \\infegal k \\infegal l(q)}  k + K(q, k, S)
            \\end{eqnarray*}
        """
        stat = self._checked_stat(word, "mks1")
        return stat.mks0, stat.mks0_, 0

    def min_dynamic_keystroke(self, word: str) -> Tuple[int, int, int]:
        """
        Returns the dynamic minimum keystrokes for a word.

        @param      word        word
        @return                 number, length of best prefix, iteration it stops moving

        This function must be called after @see me precompute_stat
        and @see me update_stat_dynamic.
        See :ref:`Dynamic Minimum Keystroke <def-mks2>`.

        .. math::
            :nowrap:

            \\begin{eqnarray*}
            K(q, k, S) &=& \\min\\acc{ i | s_i \\succ q[1..k], s_i \\in S } \\\\
            M'(q, S) &=& \\min_{0 \\infegal k \\infegal l(q)} \\acc{ M'(q[1..k], S) + K(q, k, S) | q[1..k] \\in S }
            \\end{eqnarray*}
        """
        stat = self._checked_stat(word, "mks1")
        return stat.mks1, stat.mks1_, stat.mks1i_

    def min_dynamic_keystroke2(self, word: str) -> Tuple[float, int, int]:
        """
        Returns the modified dynamic minimum keystrokes for a word.

        @param      word        word
        @return                 number, length of best prefix, iteration it stops moving

        This function must be called after @see me precompute_stat
        and @see me update_stat_dynamic.
        See :ref:`Modified Dynamic Minimum Keystroke <def-mks3>`.

        .. math::
            :nowrap:

            \\begin{eqnarray*}
            K(q, k, S) &=& \\min\\acc{ i | s_i \\succ q[1..k], s_i \\in S } \\\\
            M"(q, S) &=& \\min \\left\\{ \\begin{array}{l}
                            \\min_{1 \\infegal k \\infegal l(q)} \\acc{ M"(q[1..k-1], S) + 1 + K(q, k, S) | q[1..k] \\in S } \\\\
                            \\min_{0 \\infegal k \\infegal l(q)} \\acc{ M"(q[1..k], S) + \\delta + K(q, k, S) | q[1..k] \\in S }
                            \\end{array} \\right .
            \\end{eqnarray*}
        """
        stat = self._checked_stat(word, "mks2")
        return stat.mks2, stat.mks2_, stat.mks2i_

    def precompute_stat(self):
        """
        Computes and stores list of completions for each node,
        computes *mks*. Nodes which already hold statistics are skipped.
        """
        pending = deque(self.leaves())
        if not self.children and not self.leave:
            # empty trie: no leaf would ever reach this node
            pending.append(self)
        while pending:
            node = pending.popleft()
            if node.stat is not None:
                continue
            if not node.children:
                stat = CompletionTrieNode._Stat()
                stat.completions = []
                stat.mks0 = len(node.value)
                stat.mks0_ = len(node.value)
                node.stat = stat
            elif all(child.stat is not None for child in node.children.values()):
                stat = CompletionTrieNode._Stat()
                node.stat = stat
                if node.leave:
                    stat.mks0 = len(node.value)
                    stat.mks0_ = len(node.value)
                stat.merge_completions(node.value, node.children.values())
                stat.next_nodes = node.children
                stat.update_minimum_keystroke(len(node.value))
            else:
                # we'll do it again later
                pending.append(node)
                continue
            # Never climbs above the node the method was called on: its
            # ancestors have other subtrees this call does not process.
            if node is not self and node.parent is not None:
                pending.append(node.parent)

    def update_stat_dynamic(self, delta=0.8):
        """
        Must be called after @see me precompute_stat
        and computes dynamic mks (see :ref:`Dynamic Minimum Keystroke <def-mks2>`).

        @param      delta       parameter :math:`\\delta` in definition
                                :ref:`Modified Dynamic KeyStroke <def-mks3>`
        @return                 number of iterations to converge
        """
        for node in self.unsorted_iter():
            if node.stat is None:
                raise AttributeError(
                    "run precompute_stat before update_stat_dynamic")
            node.stat.init_dynamic_minimum_keystroke(len(node.value))
            node.stat.iter_ = 0
        updates = 1
        itera = 0
        # sweeps the whole trie until no metric decreases anymore
        while updates > 0:
            updates = 0
            stack = [self]
            while stack:
                pop = stack.pop()
                if pop.stat.iter_ > itera:
                    continue
                updates += pop.stat.update_dynamic_minimum_keystroke(
                    len(pop.value), delta)
                if pop.children:
                    stack.extend(pop.children.values())
                pop.stat.iter_ += 1
            itera += 1
        return itera

    ##
    # end of methods, beginning of subclasses
    ##

    class _Stat:
        """
        Stores statistics and intermediate data about the computation of the metrics.

        It contains the following members:

        * *mks0*: value of minimum keystroke
        * *mks0_*: length of the prefix to obtain *mks0*
        * *mks_iter*: current iteration during the computation of mks

        * *mks1*: value of dynamic minimum keystroke
        * *mks1_*: length of the prefix to obtain *mks1*
        * *mks1i_*: iteration when it was obtained

        * *mks2*: value of modified dynamic minimum keystroke
        * *mks2_*: length of the prefix to obtain *mks2*
        * *mks2i_*: iteration when it converged
        """

        def merge_completions(self, prefix: str, nodes: '[CompletionTrieNode]'):
            """
            Merges list of completions and cut the list, we assume
            given lists are sorted.

            @param      prefix      prefix held by the node (unused)
            @param      nodes       children of the node, their statistics
                                    must be already computed
            """
            # local import, the module does not import heapq
            from heapq import merge

            nodes = list(nodes)
            # Children which are completions themselves. The sort key avoids
            # comparing two nodes when their weights are equal.
            own = sorted(((n.weight, n) for n in nodes if n.leave),
                         key=lambda wn: (wn[0], wn[1].value))
            streams = [n.stat.completions for n in nodes]
            streams.append(own)
            # stable merge: for equal weights, the first list wins
            merged = list(merge(*streams, key=lambda wn: wn[0]))

            # Rank i from a prefix of length lw costs lw + i + 1 keystrokes, so
            # ranks beyond the longest completion never help. The cut is not
            # tighter (maxl - len(prefix)) because parents merge these lists
            # and need the longer ones.
            maxl = max((len(n.value) for _, n in merged), default=0)
            self.completions = merged[:maxl]

        def update_minimum_keystroke(self, lw):
            """
            Updates minimum keystroke for the completions.

            @param      lw      prefix length
            """
            for i, (_, sug) in enumerate(self.completions):
                nl = lw + i + 1
                if not hasattr(sug.stat, "mks0") or sug.stat.mks0 > nl:
                    sug.stat.mks0 = nl
                    sug.stat.mks0_ = lw

        def update_dynamic_minimum_keystroke(self, lw, delta):
            """
            Updates dynamic minimum keystroke for the completions.

            @param      lw      prefix length
            @param      delta   parameter :math:`\\delta` in definition
                                :ref:`Modified Dynamic KeyStroke <def-mks3>`
            @return             number of updates
            """
            self.mks_iter += 1
            update = 0
            for i, (_, sug) in enumerate(self.completions):
                if not sug.leave:
                    raise RuntimeError("this case should not happen")
                # this is a leave so we consider the completion being part
                # of the list of completions
                nl = self.mks1 + i + 1
                if sug.stat.mks1 > nl:
                    sug.stat.mks1 = nl
                    sug.stat.mks1_ = lw
                    sug.stat.mks1i_ = self.mks_iter
                    update += 1
                nl = self.mks2 + i + 1 + delta
                if sug.stat.mks2 > nl:
                    sug.stat.mks2 = nl
                    sug.stat.mks2_ = lw
                    sug.stat.mks2i_ = self.mks_iter
                    update += 1

            next_nodes = getattr(self, "next_nodes", None)
            if next_nodes is None:
                return update

            # optimisation of second case of modified metric:
            # one more character then a completion of the child
            for child in next_nodes.values():
                for i, (_, sug) in enumerate(child.stat.completions):
                    if not sug.leave:
                        continue
                    nl = self.mks2 + i + 2
                    if sug.stat.mks2 > nl:
                        sug.stat.mks2 = nl
                        sug.stat.mks2_ = lw
                        sug.stat.mks2i_ = self.mks_iter
                        update += 1

            # finally we need to update mks, mks2 for every prefix
            # this is not necessary a leave so it does not appear in the list of completions
            # but we need to update mks for these strings, we assume it just
            # requires an extra character, somehow, we propagate the values
            for n in next_nodes.values():
                if not hasattr(n.stat, "mks1") or n.stat.mks1 > self.mks1 + 1:
                    n.stat.mks1 = self.mks1 + 1
                    n.stat.mks1_ = self.mks1_
                    n.stat.mks1i_ = self.mks_iter
                    update += 1
                if not hasattr(n.stat, "mks2") or n.stat.mks2 > self.mks2 + 1:
                    n.stat.mks2 = self.mks2 + 1
                    n.stat.mks2_ = self.mks2_
                    n.stat.mks2i_ = self.mks_iter
                    update += 1

            return update

        def init_dynamic_minimum_keystroke(self, lw):
            """
            Initializes *mks1* and *mks2* from *mks0*.

            @param      lw      length of the prefix
            """
            if not hasattr(self, "mks0"):
                # the node is not a completion, typing every character is the
                # only known way to get it
                # NOTE(review): mks0_ is 0 here whereas mks1_ and mks2_ start
                # at lw, kept as it was - confirm this is intended
                self.mks0 = lw
                self.mks0_ = 0
                self.mks1_ = lw
                self.mks2_ = lw
            else:
                self.mks1_ = self.mks0_
                self.mks2_ = self.mks0_
            self.mks1 = self.mks0
            self.mks2 = self.mks0
            self.mks_iter = 0
            self.mks1i_ = 0
            self.mks2i_ = 0

        def str_mks0(self) -> str:
            """
            Returns a string with metric information.
            """
            if hasattr(self, "mks0"):
                return f"MKS={self.mks0} *={self.mks0_}"
            return "-"

        def str_mks(self) -> str:
            """
            Returns a string with metric information
            (*mks0*, and *mks1*, *mks2* if they were computed).
            """
            s0 = self.str_mks0()
            if not hasattr(self, "mks1"):
                return s0
            return s0 + " |'={0} *={1},{2} |\"={3} *={4},{5} |nn={6}".format(
                self.mks1, self.mks1_, self.mks1i_,
                self.mks2, self.mks2_, self.mks2i_,
                '+' if hasattr(self, "next_nodes") else '-')