Coverage for src/mlstatpy/ml/kppv.py: 90%
41 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-27 05:59 +0100
1# -*- coding: utf-8 -*-
2"""
3@file
4@brief Implements classic k-nn.
5"""
6import numpy
7import numpy.linalg
8from scipy.spatial.distance import euclidean
class NuagePoints:
    """
    Defines a cloud of points supporting a brute-force nearest
    neighbor search. The points are assumed to be stored in a matrix,
    each row being one element.
    """

    def __init__(self):
        """
        Creates an empty cloud; @see me fit must be called before
        any search.
        """
        # Defined here so the attributes always exist, even before fit.
        self.nuage = None
        self.labels = None

    def fit(self, X, y=None):
        """
        Follows sklearn API.

        @param      X       training set
        @param      y       labels
        @return             self
        """
        self.nuage = X
        self.labels = y
        # sklearn estimators return themselves so that calls can be chained.
        return self

    def kneighbors(self, X, n_neighbors=1, return_distance=True):
        """
        Return the k nearest neighbors.

        @param      X               test set
        @param      n_neighbors     number of neighbors
        @param      return_distance return distance as well
        @return                     array (dist), array (indices)

        Only ``n_neighbors=1`` and ``return_distance=True`` are implemented,
        any other value raises ``NotImplementedError``.
        Both returned arrays are one-dimensional with one entry per row
        of *X*.
        """
        if n_neighbors != 1:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented when n_neighbors != 1.")
        if not return_distance:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented when return_distance is False.")

        X = numpy.asarray(X)
        dist = numpy.zeros(X.shape[0])
        ind = numpy.zeros(X.shape[0], dtype=numpy.int64)

        # One independent nearest neighbor search per row of the test set.
        for i, row in enumerate(X):
            dist[i], ind[i] = self.ppv(row)
        return dist, ind

    @property
    def shape(self):
        """
        Returns the shape of the cloud (the training matrix).
        """
        return self.nuage.shape

    def distance(self, obj1, obj2):
        """
        Returns the euclidean distance between two elements.

        @param      obj1    object 1
        @param      obj2    object 2
        @return             distance

        Raises ``ValueError`` mentioning both shapes if the distance
        cannot be computed.
        """
        try:
            return euclidean(obj1, obj2)
        except ValueError as e:
            # numpy.shape also works on lists and tuples, so building
            # the message cannot fail on non-array inputs.
            raise ValueError(
                f"Unable to compute euclidean distance with shapes "
                f"{numpy.shape(obj1)} and {numpy.shape(obj2)}.") from e

    def label(self, i):
        """
        Returns the label of the element at index ``i``.

        @param      i       index
        @return             label or None if there is no label
        """
        # The labels are stored in ``self.labels`` by fit; ``self.label``
        # is this very method and cannot be indexed.
        return self.labels[i] if self.labels is not None else None

    def ppv(self, obj):
        """
        Returns the closest element to *obj* and its distance to *obj*.

        @param      obj     object, a vector or a matrix with a single row
        @return             ``tuple(dist, index)``

        Raises ``ValueError`` if *obj* is neither a vector nor a
        single-row matrix.
        """
        # Computing in float64 avoids wrap-around when subtracting
        # unsigned integers.
        obj = numpy.asarray(obj, dtype=numpy.float64)
        if obj.ndim == 1:
            obj = obj.reshape((1, -1))
        if obj.ndim != 2 or obj.shape[0] != 1:
            raise ValueError(
                f"obj must be a vector or a single-row matrix, "
                f"not an array of shape {obj.shape}.")
        # Broadcasting subtracts obj from every row of the cloud
        # without materialising a repeated copy of obj.
        delta = self.nuage - obj
        norm = numpy.linalg.norm(delta, axis=1)
        i = numpy.argmin(norm)
        return norm[i], i