Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2"""
3@file
4@brief Implements optimized k-nn.
5"""
6import random
7import numpy
8from .kppv import NuagePoints
11class NuagePointsLaesa (NuagePoints):
12 """
13 Implémente l'algorithme des plus proches voisins,
14 version :ref:`LAESA <space_metric_algo_laesa_prime>`_
15 """
17 def __init__(self, nb_pivots):
18 """
19 Construit la classe
21 @param nb_pivots number of pivots
22 """
23 NuagePoints.__init__(self)
24 self.nb_pivots = nb_pivots
26 def fit(self, X, y=None):
27 """
28 Follows sklearn API.
30 @param X training set
31 @param y labels
32 """
33 self.nuage = X
34 self.labels = y
35 self.selection_pivots(self.nb_pivots)
37 def selection_pivots(self, nb):
38 """
39 Sélectionne *nb* pivots aléatoirements.
41 @param nb nombre de pivots
42 """
43 nb = min(nb, self.nuage.shape[0])
44 if nb == 1:
45 self.pivots = [2]
46 else:
47 self.pivots = set()
48 while len(self.pivots) < nb:
49 i = random.randint(0, self.nuage.shape[0] - 1)
50 if i not in self.pivots:
51 self.pivots.add(i)
52 self.pivots = list(sorted(self.pivots))
54 # on calcule aussi la distance de chaque éléments au pivots
55 self.dist = numpy.zeros((self.nuage.shape[0], len(self.pivots)))
56 for i in range(self.nuage.shape[0]):
57 for j in range(len(self.pivots)): # pylint: disable=C0200
58 self.dist[i, j] = self.distance(
59 self.nuage[i, :], self.nuage[self.pivots[j], :])
61 def ppv(self, obj):
62 """
63 Retourne l'élément le plus proche de obj et sa distance avec obj,
64 utilise la sélection à l'aide pivots
66 @param obj object
67 @return ``tuple(distance, index)``
68 """
70 # initialisation
71 dp = [(self.distance(obj, self.nuage[p, :]), p, i)
72 for i, p in enumerate(self.pivots)]
74 # pivots le plus proche
75 dm, im, _ = min(dp)
77 # améliorations
78 for i in range(0, self.nuage.shape[0]):
80 # on regarde si un pivot permet d'éliminer l'élément i
81 calcul = True
82 for d, p, ip in dp:
83 delta = abs(d - self.dist[i, ip])
84 if delta > dm:
85 calcul = False
86 break
88 # dans le cas contraire on calcule la distance
89 if calcul:
90 d = self.distance(obj, self.nuage[i, :])
91 if d < dm:
92 dm = d
93 im = i
95 return dm, im