Code source de mlstatpy.ml.kppv_laesa

# -*- coding: utf-8 -*-
"""
Implements optimized k-nn.


:githublink:`%|py|6`
"""
import random
import numpy
from .kppv import NuagePoints


[docs]class NuagePointsLaesa (NuagePoints): """ Implémente l'algorithme des plus proches voisins, version :ref:`LAESA <space_metric_algo_laesa_prime>`_ :githublink:`%|py|15` """
[docs] def __init__(self, nb_pivots): """ Construit la classe :param nb_pivots: number of pivots :githublink:`%|py|22` """ NuagePoints.__init__(self) self.nb_pivots = nb_pivots
[docs] def fit(self, X, y=None): """ Follows sklearn API. :param X: training set :param y: labels :githublink:`%|py|32` """ self.nuage = X self.labels = y self.selection_pivots(self.nb_pivots)
[docs] def selection_pivots(self, nb): """ Sélectionne *nb* pivots aléatoirements. :param nb: nombre de pivots :githublink:`%|py|42` """ nb = min(nb, self.nuage.shape[0]) if nb == 1: self.pivots = [2] else: self.pivots = set() while len(self.pivots) < nb: i = random.randint(0, self.nuage.shape[0] - 1) if i not in self.pivots: self.pivots.add(i) self.pivots = list(sorted(self.pivots)) # on calcule aussi la distance de chaque éléments au pivots self.dist = numpy.zeros((self.nuage.shape[0], len(self.pivots))) for i in range(self.nuage.shape[0]): for j in range(len(self.pivots)): self.dist[i, j] = self.distance( self.nuage[i, :], self.nuage[self.pivots[j], :])
[docs] def ppv(self, obj): """ Retourne l'élément le plus proche de obj et sa distance avec obj, utilise la sélection à l'aide pivots :param obj: object :return: ``tuple(distance, index)`` :githublink:`%|py|68` """ # initialisation dp = [(self.distance(obj, self.nuage[p, :]), p, i) for i, p in enumerate(self.pivots)] # pivots le plus proche dm, im, _ = min(dp) # améliorations for i in range(0, self.nuage.shape[0]): # on regarde si un pivot permet d'éliminer l'élément i calcul = True for d, p, ip in dp: delta = abs(d - self.dist[i, ip]) if delta > dm: calcul = False break # dans le cas contraire on calcule la distance if calcul: d = self.distance(obj, self.nuage[i, :]) if d < dm: dm = d im = i return dm, im