Coverage for src/mlstatpy/nlp/completion.py: 93%
362 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-27 05:59 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-27 05:59 +0100
1"""
2@file
3@brief About completion
4"""
5from typing import Tuple, List, Iterator
6from collections import deque
class CompletionTrieNode:
    """
    Node definition in a trie used to do completion, see :ref:`l-completion0`.
    This implementation is not very efficient about memory consumption,
    it does not hold above 200.000 words.
    It should be done another way (:epkg:`cython`, :epkg:`C++`).

    Every node stores in *value* the whole prefix leading to it from the
    root (the root holds the empty string). A node flagged *leave* is a
    completion, i.e. one of the words given to @see me build.
    """

    __slots__ = ("value", "children", "weight",
                 "leave", "stat", "parent", "disp")

    def __init__(self, value, leave, weight=1.0, disp=None):
        """
        @param      value       prefix represented by the node (@see me build creates it
                                as the parent's value followed by one character)
        @param      leave       boolean (is it a completion)
        @param      weight      ordering (the lower, the first)
        @param      disp        original string, use this to identify the node
        @raises     TypeError   if *value* is not a string
        """
        if not isinstance(value, str):
            raise TypeError(
                f"value must be str not '{value}' - type={type(value)}")
        self.value = value
        # dictionary {character: child node}, None while the node has no child
        self.children = None
        self.weight = weight
        self.leave = leave
        # instance of CompletionTrieNode._Stat filled by precompute_stat
        self.stat = None
        self.parent = None
        self.disp = disp

    @property
    def root(self):
        """
        Returns the initial node with no parent.
        """
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    def __str__(self):
        """
        Returns ``[#:value:w=weight]``, ``#`` means the node is a completion,
        ``-`` means it is not.
        """
        return f"[{'#' if self.leave else '-'}:{self.value}:w={self.weight}]"

    def _add(self, key, child):
        """
        Adds a child and sets its parent.

        @param      key         one letter of the word
        @param      child       child
        @return                 self
        @raises     KeyError    if *key* is already a child
        """
        if self.children is None:
            self.children = {}
        elif key in self.children:
            raise KeyError(f"'{key}' already added")
        self.children[key] = child
        child.parent = self
        return self

    def items_list(self) -> List['CompletionTrieNode']:
        """
        All children nodes including itself in a list,
        depth first, children sorted by key.

        @return     list of CompletionTrieNode
        """
        res = [self]
        if self.children is not None:
            for _, v in sorted(self.children.items()):
                res.extend(v.items_list())
        return res

    def __iter__(self):
        """
        Iterates on all nodes (sorted): depth first,
        children visited by increasing key.
        """
        stack = [self]
        while len(stack) > 0:
            node = stack.pop()
            yield node
            if node.children:
                # pushed in reverse order so that the smallest key is popped first
                stack.extend(v for k, v in sorted(
                    node.children.items(), reverse=True))

    def unsorted_iter(self):
        """
        Iterates on all nodes, children are not sorted.
        """
        stack = [self]
        while len(stack) > 0:
            node = stack.pop()
            yield node
            if node.children:
                stack.extend(node.children.values())

    def items(self) -> Iterator[Tuple[float, str, 'CompletionTrieNode']]:
        """
        Iterates on the direct children, yields ``(weight, key, child)``.
        """
        if self.children is not None:
            for k, v in self.children.items():
                yield v.weight, k, v

    def iter_leaves(self, max_weight=None) -> Iterator[Tuple[float, str]]:
        """
        Iterates on the completions below this node (itself included)
        sorted by weight then by value, yields ``(weight, value)``.

        @param      max_weight      keep all value under this threshold or None for all
        """
        selected = [(node.weight, node.value) for node in self.leaves()
                    if max_weight is None or node.weight <= max_weight]
        selected.sort()
        yield from selected

    def leaves(self) -> Iterator['CompletionTrieNode']:
        """
        Iterates on the completions (nodes with *leave* set)
        below this node, itself included, in no particular order.
        """
        stack = [self]
        while len(stack) > 0:
            pop = stack.pop()
            if pop.leave:
                yield pop
            if pop.children:
                stack.extend(pop.children.values())

    def all_completions(self) -> List[Tuple['CompletionTrieNode', List[str]]]:
        """
        Retrieves all completions for every prefix of this node,
        from the root to the node itself,
        the method does not need @see me precompute_stat to be run first.

        @return     list of ``(node, list of completions sorted by weight)``
        """
        word = self.value
        nodes = [self.root]
        node = nodes[0]
        for c in word:
            if node.children is not None and c in node.children:
                node = node.children[c]
                nodes.append(node)
        return [(n, [value for _, value in n.iter_leaves()]) for n in nodes]

    def all_mks_completions(self) -> List[Tuple['CompletionTrieNode', List[Tuple[float, 'CompletionTrieNode']]]]:
        """
        Retrieves all completions for every prefix of this node,
        from the root to the node itself,
        the method assumes @see me precompute_stat was run.

        @return     list of ``(node, node.stat.completions)``
        """
        res = []
        node = self
        while node is not None:
            res.append((node, node.stat.completions))
            node = node.parent
        res.reverse()
        return res

    def str_all_completions(self, maxn=10, use_precompute=True) -> str:
        """
        Builds a string with all completions for all
        prefixes along the paths.

        @param      maxn                maximum number of completions to show
                                        for every prefix (None for all)
        @param      use_precompute      use intermediate results built by @see me precompute_stat
        @return                         str
        """
        res = (self.all_mks_completions() if use_precompute
               else self.all_completions())
        rows = []
        for node, sug in res:
            # stat is None if precompute_stat was never run,
            # possible when use_precompute is False
            smks = node.stat.str_mks() if node.stat is not None else "-"
            rows.append("l={3} p='{0}' {1} {2}".format(
                node.value, "-" * 10, smks, '+' if node.leave else '-'))
            for i, s in enumerate(sug):
                if maxn is not None and i >= maxn:
                    break
                if isinstance(s, str):
                    # produced by all_completions
                    rows.append(f" {i + 1}-'{s}'")
                else:
                    # (weight, node) produced by all_mks_completions
                    rows.append(
                        f" {i + 1}-w{s[0]}-'{s[1].value}'")
        return "\n".join(rows)

    @staticmethod
    def build(words) -> 'CompletionTrieNode':
        """
        Builds a trie.

        @param      words       list of ``(word)`` or ``(weight, word)`` or ``(weight, word, display string)``,
                                a weight equal to None is replaced by the position of the word in the list
        @return                 root of the trie (CompletionTrieNode)
        @raises     ValueError  if a tuple has an unexpected length or a word appears twice
        """
        root = CompletionTrieNode('', False)
        nb = 0
        minw = None
        for wword in words:
            if isinstance(wword, tuple):
                if len(wword) == 2:
                    w, word = wword
                    disp = None
                elif len(wword) == 3:
                    w, word, disp = wword
                else:
                    raise ValueError(
                        f"Unexpected number of values, it should be (weight, word) or (weight, word, dispplay string): {wword}")
            else:
                w = 1.0
                word = wword
                disp = None
            if w is None:
                w = nb
            if minw is None or minw > w:
                minw = w
            node = root
            new_node = None
            for c in word:
                if node.children is not None and c in node.children:
                    # existing path: a node which is not a completion
                    # keeps the lowest weight seen so far
                    if not node.leave:
                        node.weight = min(node.weight, w)
                    node = node.children[c]
                else:
                    new_node = CompletionTrieNode(
                        node.value + c, False, weight=w)
                    node._add(c, new_node)
                    node = new_node
            if new_node is None:
                # the whole word was already a path in the trie
                if node.leave:
                    raise ValueError(
                        f"Value '{word}' appears twice in the input list (not allowed).")
                new_node = node
            new_node.leave = True
            new_node.weight = w
            if disp is not None:
                new_node.disp = disp
            nb += 1
        root.weight = minw
        return root

    def find(self, prefix: str) -> 'CompletionTrieNode':
        """
        Returns the node which holds all completions starting with a given prefix.

        @param      prefix      prefix
        @return                 node or None for no result
        @raises     ValueError  if *prefix* is empty and this node is not a root
        """
        if len(prefix) == 0:
            if not self.value:
                return self
            raise ValueError(
                f"find '{prefix}' but node is not empty '{self.value}'")
        node = self
        for c in prefix:
            if node.children is not None and c in node.children:
                node = node.children[c]
            else:
                return None
        return node

    def min_keystroke(self, word: str) -> Tuple[int, int]:
        """
        Returns the minimum keystrokes for a word without optimisation,
        this function should be used if you only have a couple of values to
        compute. You should use @see me min_keystroke0 to compute all of them.

        @param      word        word
        @return                 number, length of best prefix
                                (``len(word), -1`` if *word* is not a completion of the trie)

        See :ref:`l-completion-optim`.

        .. math::
            :nowrap:

            \\begin{eqnarray*}
            K(q, k, S) &=& \\min\\acc{ i | s_i \\succ q[1..k], s_i \\in S } \\\\
            M(q, S) &=& \\min_{0 \\infegal k \\infegal l(q)}  k + K(q, k, S)
            \\end{eqnarray*}
        """
        nodes = [self]
        node = self
        for c in word:
            if node.children is not None and c in node.children:
                node = node.children[c]
                nodes.append(node)
            else:
                # not found
                return len(word), -1
        nodes.reverse()
        metric = len(word)
        best = len(word)
        # nodes[0] is the word itself, typing it entirely costs len(word)
        for node in nodes[1:]:
            res = [n[1] for n in node.iter_leaves()]
            try:
                ind = res.index(word)
            except ValueError:
                # the word is a prefix in the trie but not a completion,
                # handled like a word outside the trie
                return len(word), -1
            m = len(node.value) + ind + 1
            if m < metric:
                metric = m
                best = len(node.value)
            if ind >= len(word):
                # no need to go further, the position will increase
                break
        return metric, best

    def _find_with_stat(self, word, attr):
        """
        Returns the node for *word* after checking the statistic *attr*
        was computed by @see me precompute_stat and @see me update_stat_dynamic.

        @param      word        word
        @param      attr        name of the statistic the caller reads
        @return                 node
        @raises     NotImplementedError     if *word* is not in the trie
        @raises     AttributeError          if the statistics are missing
        """
        node = self.find(word)
        if node is None:
            raise NotImplementedError(
                f"this metric is not yet computed for a query outside the trie: '{word}'")
        if node.stat is None:
            raise AttributeError("run precompute_stat and update_stat_dynamic")
        if not hasattr(node.stat, attr):
            raise AttributeError("run precompute_stat and update_stat_dynamic\nnode={0}\n{1}".format(
                node, "\n".join(sorted(node.stat.__dict__.keys()))))
        return node

    def min_keystroke0(self, word: str) -> Tuple[int, int, int]:
        """
        Returns the minimum keystrokes for a word.

        @param      word        word
        @return                 number, length of best prefix, iteration it stops moving (always 0)

        This function must be called after @see me precompute_stat
        and @see me update_stat_dynamic.

        See :ref:`l-completion-optim`.

        .. math::
            :nowrap:

            \\begin{eqnarray*}
            K(q, k, S) &=& \\min\\acc{ i | s_i \\succ q[1..k], s_i \\in S } \\\\
            M(q, S) &=& \\min_{0 \\infegal k \\infegal l(q)}  k + K(q, k, S)
            \\end{eqnarray*}
        """
        node = self._find_with_stat(word, "mks1")
        return node.stat.mks0, node.stat.mks0_, 0

    def min_dynamic_keystroke(self, word: str) -> Tuple[int, int, int]:
        """
        Returns the dynamic minimum keystrokes for a word.

        @param      word        word
        @return                 number, length of best prefix, iteration it stops moving

        This function must be called after @see me precompute_stat
        and @see me update_stat_dynamic.
        See :ref:`Dynamic Minimum Keystroke <def-mks2>`.

        .. math::
            :nowrap:

            \\begin{eqnarray*}
            K(q, k, S) &=& \\min\\acc{ i | s_i \\succ q[1..k], s_i \\in S } \\\\
            M'(q, S) &=& \\min_{0 \\infegal k \\infegal l(q)} \\acc{ M'(q[1..k], S) + K(q, k, S) | q[1..k] \\in S }
            \\end{eqnarray*}
        """
        node = self._find_with_stat(word, "mks1")
        return node.stat.mks1, node.stat.mks1_, node.stat.mks1i_

    def min_dynamic_keystroke2(self, word: str) -> Tuple[float, int, int]:
        """
        Returns the modified dynamic minimum keystrokes for a word.

        @param      word        word
        @return                 number, length of best prefix, iteration it stops moving

        This function must be called after @see me precompute_stat
        and @see me update_stat_dynamic.
        See :ref:`Modified Dynamic Minimum Keystroke <def-mks3>`.

        .. math::
            :nowrap:

            \\begin{eqnarray*}
            K(q, k, S) &=& \\min\\acc{ i | s_i \\succ q[1..k], s_i \\in S } \\\\
            M"(q, S) &=& \\min \\left\\{ \\begin{array}{l}
                    \\min_{1 \\infegal k \\infegal l(q)} \\acc{ M"(q[1..k-1], S) + 1 + K(q, k, S) | q[1..k] \\in S } \\\\
                    \\min_{0 \\infegal k \\infegal l(q)} \\acc{ M"(q[1..k], S) + \\delta + K(q, k, S) | q[1..k] \\in S }
                    \\end{array} \\right .
            \\end{eqnarray*}
        """
        node = self._find_with_stat(word, "mks2")
        return node.stat.mks2, node.stat.mks2_, node.stat.mks2i_

    def precompute_stat(self):
        """
        Computes and stores list of completions for each node,
        computes *mks*. Nodes are processed from the leaves to the root,
        a node is processed once all its children were.
        """
        stack = deque()
        stack.extend(self.leaves())
        while len(stack) > 0:
            pop = stack.popleft()
            if pop.stat is not None:
                # already processed
                continue
            if not pop.children:
                pop.stat = CompletionTrieNode._Stat()
                pop.stat.completions = []
                pop.stat.mks0 = len(pop.value)
                pop.stat.mks0_ = len(pop.value)
                if pop.parent is not None:
                    stack.append(pop.parent)
            elif all(v.stat is not None for v in pop.children.values()):
                pop.stat = CompletionTrieNode._Stat()
                if pop.leave:
                    pop.stat.mks0 = len(pop.value)
                    pop.stat.mks0_ = len(pop.value)
                stack.extend(pop.children.values())
                pop.stat.merge_completions(pop.value, pop.children.values())
                pop.stat.next_nodes = pop.children
                pop.stat.update_minimum_keystroke(len(pop.value))
                if pop.parent is not None:
                    stack.append(pop.parent)
            else:
                # we'll do it again later
                stack.append(pop)

    def update_stat_dynamic(self, delta=0.8):
        """
        Must be called after @see me precompute_stat
        and computes dynamic mks (see :ref:`Dynamic Minimum Keystroke <def-mks2>`).

        @param      delta       parameter :math:`\\delta` in definition
                                :ref:`Modified Dynamic KeyStroke <def-mks3>`
        @return                 number of iterations to converge
        """
        for node in self.unsorted_iter():
            node.stat.init_dynamic_minimum_keystroke(len(node.value))
            node.stat.iter_ = 0
        updates = 1
        itera = 0
        while updates > 0:
            updates = 0
            stack = [self]
            while len(stack) > 0:
                pop = stack.pop()
                if pop.stat.iter_ > itera:
                    continue
                updates += pop.stat.update_dynamic_minimum_keystroke(
                    len(pop.value), delta)
                if pop.children:
                    stack.extend(pop.children.values())
                pop.stat.iter_ += 1
            itera += 1
        return itera

    ##
    # end of methods, beginning of subclasses
    ##

    class _Stat:
        """
        Stores statistics and intermediate data about the computation of the metrics.

        It contains the following members:

        * *completions*: list of ``(weight, node)`` sorted by weight
        * *next_nodes*: children of the node (only for nodes with children)

        * *mks0*: value of minimum keystroke
        * *mks0_*: length of the prefix to obtain *mks0*
        * *mks_iter*: current iteration during the computation of mks

        * *mks1*: value of dynamic minimum keystroke
        * *mks1_*: length of the prefix to obtain *mks1*
        * *mks1i_*: iteration when it was obtained

        * *mks2*: value of modified dynamic minimum keystroke
        * *mks2_*: length of the prefix to obtain *mks2*
        * *mks2i_*: iteration when it was obtained
        """

        def merge_completions(self, prefix: str, nodes: 'List[CompletionTrieNode]'):
            """
            Merges the lists of completions of *nodes* (the children of a node)
            with the nodes themselves when they are completions and stores the
            result as ``(weight, node)`` sorted by weight in *completions*.
            Every given list is assumed to be sorted by weight. The result is cut
            to the length of the longest completion it contains.

            @param      prefix      prefix of the node (unused)
            @param      nodes       iterable on children nodes, all with a *stat*
            """
            import heapq
            nodes = list(nodes)
            # the children which are completions, ties on the weight are broken
            # by value: comparing two nodes directly raises a TypeError
            own = sorted(((n.weight, n) for n in nodes if n.leave),
                         key=lambda wn: (wn[0], wn[1].value))
            lists = [n.stat.completions for n in nodes]
            lists.append(own)
            # heapq.merge breaks ties on the weight with the position of the list
            res = list(heapq.merge(*lists, key=lambda wn: wn[0]))
            maxl = max((len(n.value) for _, n in res), default=0)
            # maxl - len(prefix) represents the longest list which reduces the number of keystrokes
            # however, as the method aggregates completions at a lower level,
            # we must keep longer completions for lower levels
            self.completions = res[:maxl]

        def update_minimum_keystroke(self, lw):
            """
            Updates minimum keystroke for the completions.

            @param      lw      prefix length
            """
            for i, wsug in enumerate(self.completions):
                sug = wsug[1]
                # type the prefix then select the i-th suggestion
                nl = lw + i + 1
                if not hasattr(sug.stat, "mks0") or sug.stat.mks0 > nl:
                    sug.stat.mks0 = nl
                    sug.stat.mks0_ = lw

        def update_dynamic_minimum_keystroke(self, lw, delta):
            """
            Updates dynamic minimum keystroke for the completions.

            @param      lw          prefix length
            @param      delta       parameter :math:`\\delta` in definition
                                    :ref:`Modified Dynamic KeyStroke <def-mks3>`
            @return                 number of updates
            """
            self.mks_iter += 1
            update = 0
            for i, wsug in enumerate(self.completions):
                sug = wsug[1]
                if sug.leave:
                    # this is a leave so we consider the completion being part
                    # of the list of completions
                    nl = self.mks1 + i + 1
                    if sug.stat.mks1 > nl:
                        sug.stat.mks1 = nl
                        sug.stat.mks1_ = lw
                        sug.stat.mks1i_ = self.mks_iter
                        update += 1
                    nl = self.mks2 + i + 1 + delta
                    if sug.stat.mks2 > nl:
                        sug.stat.mks2 = nl
                        sug.stat.mks2_ = lw
                        sug.stat.mks2i_ = self.mks_iter
                        update += 1
                else:
                    raise RuntimeError("this case should not happen")

            # optimisation of second case of modified metric
            # in a separate function for profiling
            def second_step(update):
                if hasattr(self, "next_nodes"):
                    for _, child in self.next_nodes.items():
                        for i, wsug in enumerate(child.stat.completions):
                            sug = wsug[1]
                            if not sug.leave:
                                continue
                            nl = self.mks2 + i + 2
                            if sug.stat.mks2 > nl:
                                sug.stat.mks2 = nl
                                sug.stat.mks2_ = lw
                                sug.stat.mks2i_ = self.mks_iter
                                update += 1
                return update

            update = second_step(update)

            # finally we need to update mks, mks2 for every prefix
            # this is not necessary a leave so it does not appear in the list of completions
            # but we need to update mks for these strings, we assume it just
            # requires an extra character, somehow, we propagate the values
            if hasattr(self, "next_nodes"):
                for _, n in self.next_nodes.items():
                    if not hasattr(n.stat, "mks1") or n.stat.mks1 > self.mks1 + 1:
                        n.stat.mks1 = self.mks1 + 1
                        n.stat.mks1_ = self.mks1_
                        n.stat.mks1i_ = self.mks_iter
                        update += 1
                    if not hasattr(n.stat, "mks2") or n.stat.mks2 > self.mks2 + 1:
                        n.stat.mks2 = self.mks2 + 1
                        n.stat.mks2_ = self.mks2_
                        n.stat.mks2i_ = self.mks_iter
                        update += 1

            return update

        def init_dynamic_minimum_keystroke(self, lw):
            """
            Initializes *mks1* and *mks2* from *mks0*, or from *lw*
            when *mks0* was never computed (*mks0* is then set to *lw*).

            @param      lw      length of the prefix
            """
            if hasattr(self, "mks0"):
                self.mks1 = self.mks0
                self.mks1_ = self.mks0_
                self.mks_iter = 0
                self.mks1i_ = 0
                self.mks2 = self.mks0
                self.mks2_ = self.mks0_
                self.mks2i_ = 0
            else:
                self.mks0 = lw
                self.mks0_ = 0
                self.mks1 = lw
                self.mks1_ = lw
                self.mks_iter = 0
                self.mks1i_ = 0
                self.mks2 = lw
                self.mks2_ = lw
                self.mks2i_ = 0

        def str_mks0(self) -> str:
            """
            Returns a string with the minimum keystroke and its prefix length,
            ``-`` if it was not computed.
            """
            if hasattr(self, "mks0"):
                return f"MKS={self.mks0} *={self.mks0_}"
            return "-"

        def str_mks(self) -> str:
            """
            Returns a string with all metrics:
            value, prefix length and iteration for *mks1* and *mks2*,
            ``nn=+`` when the node has children.
            """
            s0 = self.str_mks0()
            if hasattr(self, "mks1"):
                return s0 + " |'={0} *={1},{2} |\"={3} *={4},{5} |nn={6}".format(
                    self.mks1, self.mks1_, self.mks1i_, self.mks2, self.mks2_, self.mks2i_,
                    '+' if hasattr(self, "next_nodes") else '-')
            return s0