Encontrar K vecinos más cercanos y su implementación

Estoy trabajando en la clasificación de datos simples usando KNN con distancia euclidiana. He visto un ejemplo de lo que me gustaría hacer que se hace con la función knnsearch MATLAB, como se muestra a continuación:

 load fisheriris x = meas(:,3:4); gscatter(x(:,1),x(:,2),species) newpoint = [5 1.45]; [n,d] = knnsearch(x,newpoint,'k',10); line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10) 

El código anterior toma un nuevo punto, es decir [5 1.45] y encuentra los 10 valores más cercanos al nuevo punto. ¿Alguien puede mostrarme un algoritmo MATLAB con una explicación detallada de lo que hace la función knnsearch ? Hay alguna otra manera de hacer esto?

La base del algoritmo K-Nearest Neighbourhood (KNN) es que tiene una matriz de datos que consta de N filas y M columnas donde N es la cantidad de puntos de datos que tenemos, mientras que M es la dimensionalidad de cada punto de datos. Por ejemplo, si colocamos coordenadas cartesianas dentro de una matriz de datos, esta suele ser una matriz N x 2 o N x 3 . Con esta matriz de datos, proporciona un punto de consulta y busca los puntos k más cercanos dentro de esta matriz de datos que son los más cercanos a este punto de consulta.

Por lo general, usamos la distancia euclidiana entre la consulta y el rest de sus puntos en su matriz de datos para calcular nuestras distancias. Sin embargo, también se usan otras distancias como L1 o City-Block / Manhattan. Después de esta operación, tendrá N distancias euclidianas o de Manhattan que simbolizan las distancias entre la consulta y cada punto correspondiente en el conjunto de datos. Una vez que encuentre estos, simplemente busque los k puntos más cercanos a la consulta ordenando las distancias en orden ascendente y recuperando esos k puntos que tienen la menor distancia entre su conjunto de datos y la consulta.

Suponiendo que su matriz de datos se almacenó en x , y newpoint es un punto de muestra donde tiene M columnas (es decir, 1 x M ), este es el procedimiento general que seguiría en forma de punto:

  1. Encuentre la distancia euclidiana o de Manhattan entre el punto newpoint y cada punto en x .
  2. Clasifique estas distancias en orden ascendente.
  3. Devuelve los k puntos de datos en x que están más cerca del punto newpoint .

Hagamos cada paso lentamente.


Paso 1

Una forma en que alguien puede hacer esto es tal vez en un ciclo for como sigue:

 N = size(x,1); dists = zeros(N,1); for idx = 1 : N dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2)); end 

Si quisiera implementar la distancia de Manhattan, esto sería simplemente:

 N = size(x,1); dists = zeros(N,1); for idx = 1 : N dists(idx) = sum(abs(x(idx,:) - newpoint)); end 

dists sería un vector de elemento N que contiene las distancias entre cada punto de datos en x y newpoint . Hacemos una resta elemento por elemento entre newpoint y un punto de datos en x , cuadra las diferencias, luego las newpoint todas juntas. Esta sum tiene raíz cuadrada, que completa la distancia euclidiana. Para la distancia de Manhattan, debe realizar un elemento por sustracción de elementos, tomar los valores absolutos y luego sumr todos los componentes. Esta es probablemente la más simple de las implementaciones para entender, pero posiblemente podría ser la más ineficiente … especialmente para conjuntos de datos de mayor tamaño y una mayor dimensionalidad de sus datos.

Otra posible solución sería replicar newpoint y hacer que esta matriz tenga el mismo tamaño que x , luego hacer una resta elemento por elemento de esta matriz, luego sumr todas las columnas de cada fila y hacer la raíz cuadrada. Por lo tanto, podemos hacer algo como esto:

 N = size(x, 1); dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2)); 

Para la distancia de Manhattan, harías:

 N = size(x, 1); dists = sum(abs(x - repmat(newpoint, N, 1)), 2); 

repmat toma una matriz o vector y los repite una cierta cantidad de veces en una dirección dada. En nuestro caso, queremos tomar nuestro vector newpoint y astackr N veces uno encima del otro para crear una matriz N x M , donde cada fila tiene M elementos de longitud. Restamos estas dos matrices juntas, luego cuadramos cada componente. Una vez que hacemos esto, sum todas las columnas de cada fila y finalmente tomamos la raíz cuadrada de todos los resultados. Para la distancia de Manhattan, hacemos la resta, tomamos el valor absoluto y luego summos.

Sin embargo, la forma más eficiente de hacer esto en mi opinión sería usar bsxfun . Esto esencialmente hace la replicación de la que hablamos bajo el capó con una sola llamada de función. Por lo tanto, el código sería simplemente esto:

 dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); 

Para mí esto se ve mucho más limpio y al grano. Para la distancia de Manhattan, harías:

 dists = sum(abs(bsxfun(@minus, x, newpoint)), 2); 

Paso 2

Ahora que tenemos nuestras distancias, simplemente las clasificamos. Podemos usar el sort para ordenar nuestras distancias:

 [d,ind] = sort(dists); 

d contendría las distancias ordenadas en orden ascendente, mientras que ind le ind cada valor en la matriz no ordenada donde aparece en el resultado ordenado . Necesitamos usar ind , extraer los primeros k elementos de este vector, luego usar ind para indexar en nuestra matriz de datos x para devolver aquellos puntos que estuvieron más cerca de newpoint .

Paso 3

El último paso es devolver esos k puntos de datos que están más cerca de newpoint . Podemos hacer esto de manera muy simple por:

 ind_closest = ind(1:k); x_closest = x(ind_closest,:); 

ind_closest debe contener los índices en la matriz de datos original x que son los más cercanos a newpoint . Específicamente, ind_closest contiene las filas de las que debe ind_closest muestras en x para obtener los puntos más cercanos a newpoint . x_closest contendrá esos puntos de datos reales.


Para su placer de copiar y pegar, así es como se ve el código:

 dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); %// Or do this for Manhattan % dists = sum(abs(bsxfun(@minus, x, newpoint)), 2); [d,ind] = sort(dists); ind_closest = ind(1:k); x_closest = x(ind_closest,:); 

Ejecutando tu ejemplo, veamos nuestro código en acción:

 load fisheriris x = meas(:,3:4); newpoint = [5 1.45]; k = 10; %// Use Euclidean dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); [d,ind] = sort(dists); ind_closest = ind(1:k); x_closest = x(ind_closest,:); 

Al inspeccionar ind_closest y x_closest , esto es lo que obtenemos:

 >> ind_closest ind_closest = 120 53 73 134 84 77 78 51 64 87 >> x_closest x_closest = 5.0000 1.5000 4.9000 1.5000 4.9000 1.5000 5.1000 1.5000 5.1000 1.6000 4.8000 1.4000 5.0000 1.7000 4.7000 1.4000 4.7000 1.4000 4.7000 1.5000 

Si ejecutó knnsearch , verá que su variable n coincide con ind_closest . Sin embargo, la variable d devuelve las distancias desde el punto newpoint a cada punto x , no los puntos de datos reales. Si desea las distancias reales, simplemente haga lo siguiente después del código que escribí:

 dist_sorted = d(1:k); 

Tenga en cuenta que la respuesta anterior utiliza solo un punto de consulta en un lote de N ejemplos. Con mucha frecuencia, KNN se usa en múltiples ejemplos simultáneamente. Supongamos que tenemos Q puntos de consulta que queremos probar en el KNN. Esto daría como resultado una matriz kx M x Q donde para cada ejemplo o cada sector, devolvemos los k puntos más cercanos con una dimensionalidad de M Alternativamente, podemos devolver las ID de los k puntos más cercanos, lo que da como resultado una matriz Q xk . Vamos a calcular ambos.

Una forma ingenua de hacer esto sería aplicar el código anterior en un bucle y repetir cada ejemplo.

Algo así funcionaría donde bsxfun una matriz Q xk y aplicamos el bsxfun basado en bsxfun para establecer cada fila de la matriz de salida en los k puntos más cercanos en el conjunto de datos, donde usaremos el conjunto de datos de Fisher Iris tal como lo teníamos antes. También mantendremos la misma dimensionalidad que hicimos en el ejemplo anterior y usaré cuatro ejemplos, entonces Q = 4 y M = 2 :

 %// Load the data and create the query points load fisheriris; x = meas(:,3:4); newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5]; %// Define k and the output matrices Q = size(newpoints, 1); M = size(x, 2); k = 10; x_closest = zeros(k, M, Q); ind_closest = zeros(Q, k); %// Loop through each point and do logic as seen above: for ii = 1 : Q %// Get the point newpoint = newpoints(ii, :); %// Use Euclidean dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); [d,ind] = sort(dists); %// New - Output the IDs of the match as well as the points themselves ind_closest(ii, :) = ind(1 : k).'; x_closest(:, :, ii) = x(ind_closest(ii, :), :); end 

Aunque esto es muy bueno, podemos hacerlo aún mejor. Hay una manera de calcular eficientemente la distancia euclidiana al cuadrado entre dos conjuntos de vectores. Lo dejaré como ejercicio si quieres hacer esto con el Manhattan. Consultar este blog , dado que A es una matriz Q1 x M donde cada fila es un punto de dimensionalidad M con puntos Q1 y B es una matriz Q2 x M donde cada fila es también un punto de dimensionalidad M con puntos Q2 , podemos eficientemente calcular una matriz de distancia D(i, j) donde el elemento en la fila i y la columna j denota la distancia entre la fila i de A y la fila j de B usando la siguiente formulación de matriz:

 nA = sum(A.^2, 2); %// Sum of squares for each row of A nB = sum(B.^2, 2); %// Sum of squares for each row of B D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Compute distance matrix D = sqrt(D); %// Compute square root to complete calculation 

Por lo tanto, si permitimos que A sea ​​una matriz de puntos de consulta y B el conjunto de datos que consiste en sus datos originales, podemos determinar los k puntos más cercanos ordenando cada fila individualmente y determinando las k ubicaciones de cada fila que fueron las más pequeñas. También podemos usar esto para recuperar los puntos reales.

Por lo tanto:

 %// Load the data and create the query points load fisheriris; x = meas(:,3:4); newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5]; %// Define k and other variables k = 10; Q = size(newpoints, 1); M = size(x, 2); nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A nB = sum(x.^2, 2); %// Sum of squares for each row of B D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Compute distance matrix D = sqrt(D); %// Compute square root to complete calculation %// Sort the distances [d, ind] = sort(D, 2); %// Get the indices of the closest distances ind_closest = ind(:, 1:k); %// Also get the nearest points x_closest = permute(reshape(x(ind_closest(:), :).', M, k, []), [2 1 3]); 

Vemos que usamos la lógica para calcular que la matriz de distancia es la misma, pero algunas variables han cambiado para adecuarse al ejemplo. También ordenamos cada fila de forma independiente utilizando las dos versiones de entrada de sort modo que ind contenga los ID por fila d contendrá las distancias correspondientes. Luego determinamos qué índices son los más cercanos a cada punto de consulta simplemente truncando esta matriz a k columnas. Luego usamos permute y permute para determinar cuáles son los puntos más cercanos asociados. Primero usamos todos los índices más cercanos y creamos una matriz de puntos que astack todas las ID una encima de la otra para obtener una matriz Q * kx M . El uso de reshape y permute nos permite crear nuestra matriz 3D para que se convierta en una matriz kx M x Q como la que hemos especificado. Si quisieras obtener las distancias reales, podemos indexar en d y tomar lo que necesitamos. Para hacer esto, necesitarás usar sub2ind para obtener los índices lineales para que podamos indexar en d de una vez. Los valores de ind_closest ya nos dan a qué columnas debemos acceder. Las filas a las que debemos acceder son simplemente 1, k veces, 2, k veces, etc. hasta Q k es la cantidad de puntos que queríamos devolver:

 row_indices = repmat((1:Q).', 1, k); linear_ind = sub2ind(size(d), row_indices, ind_closest); dist_sorted = D(linear_ind); 

Cuando ejecutamos el código anterior para los puntos de consulta anteriores, estos son los índices, puntos y distancias que obtenemos:

 >> ind_closest ind_closest = 120 134 53 73 84 77 78 51 64 87 123 119 118 106 132 108 131 136 126 110 107 62 86 122 71 127 139 115 60 52 99 65 58 94 60 61 80 44 54 72 >> x_closest x_closest(:,:,1) = 5.0000 1.5000 6.7000 2.0000 4.5000 1.7000 3.0000 1.1000 5.1000 1.5000 6.9000 2.3000 4.2000 1.5000 3.6000 1.3000 4.9000 1.5000 6.7000 2.2000 x_closest(:,:,2) = 4.5000 1.6000 3.3000 1.0000 4.9000 1.5000 6.6000 2.1000 4.9000 2.0000 3.3000 1.0000 5.1000 1.6000 6.4000 2.0000 4.8000 1.8000 3.9000 1.4000 x_closest(:,:,3) = 4.8000 1.4000 6.3000 1.8000 4.8000 1.8000 3.5000 1.0000 5.0000 1.7000 6.1000 1.9000 4.8000 1.8000 3.5000 1.0000 4.7000 1.4000 6.1000 2.3000 x_closest(:,:,4) = 5.1000 2.4000 1.6000 0.6000 4.7000 1.4000 6.0000 1.8000 3.9000 1.4000 4.0000 1.3000 4.7000 1.5000 6.1000 2.5000 4.5000 1.5000 4.0000 1.3000 >> dist_sorted dist_sorted = 0.0500 0.1118 0.1118 0.1118 0.1803 0.2062 0.2500 0.3041 0.3041 0.3041 0.3000 0.3162 0.3606 0.4123 0.6000 0.7280 0.9055 0.9487 1.0198 1.0296 0.9434 1.0198 1.0296 1.0296 1.0630 1.0630 1.0630 1.1045 1.1045 1.1180 2.6000 2.7203 2.8178 2.8178 2.8320 2.9155 2.9155 2.9275 2.9732 2.9732 

Para comparar esto con knnsearch , en su lugar, debe especificar una matriz de puntos para el segundo parámetro donde cada fila es un punto de consulta y verá que los índices y las distancias ordenadas coinciden entre esta implementación y knnsearch .


Espero que esto te ayude. ¡Buena suerte!

    Intereting Posts