K Nearest Neighbors
K Nearest Neighbors (kNN) is a popular classification algorithm and a member of the family of instance-based learning methods.
```julia
using Manifolds, ManifoldML, NearestNeighbors, StatsBase, Statistics

M = SymmetricPositiveDefinite(3)
N = 100
# random symmetric positive definite matrices (sample covariances) and labels
pts = [cov(randn(10, 3)) for _ in 1:N]
ys = [rand([1, 2]) for _ in 1:N]
dist = ManifoldML.RiemannianDistance(M)
# flatten each 3×3 matrix into a length-9 column of the point matrix
point_matrix = reduce(hcat, map(a -> reshape(a, 9), pts))
balltree = BallTree(point_matrix, dist)

function classify(tree, p, ys, k)
    num_coords = prod(representation_size(tree.metric.manifold))
    idxs, dists = knn(tree, reshape(p, num_coords), k)
    return mode(ys[idxs])
end

classify(balltree, pts[2], ys, 3)
```
The code performs the following steps:
- Creating random data `pts` on the manifold of symmetric positive definite matrices (i.e. covariance matrices).
- Creating random class labels `ys`.
- Selecting the distance `dist` to use in kNN.
- Collecting coordinates of points in a matrix `point_matrix`.
- Creating a nearest neighbor search tree using the `NearestNeighbors` package.
- Writing a simple kNN classifier that reshapes the point `p`, performs a kNN search for `k` nearest neighbors, and returns the most common label using the function `mode` from `StatsBase`.
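The steps above can be sketched in plain Julia without the manifold machinery, replacing the Riemannian distance with the Euclidean distance on flattened matrices (the name `knn_classify`, the brute-force search, and the Euclidean metric are illustrative assumptions here, not part of ManifoldML):

```julia
using LinearAlgebra, Statistics, Random

Random.seed!(42)

# Random SPD matrices (sample covariances) and random class labels, as above
N = 100
pts = [cov(randn(10, 3)) for _ in 1:N]
ys = [rand([1, 2]) for _ in 1:N]

# Brute-force kNN with Euclidean distance on flattened matrices.
# (The manifold version replaces this metric with the Riemannian distance.)
function knn_classify(pts, ys, p, k)
    dists = [norm(vec(q) - vec(p)) for q in pts]
    idxs = partialsortperm(dists, 1:k)   # indices of the k nearest points
    # majority vote: count labels among the k nearest neighbors
    counts = Dict{Int,Int}()
    for y in ys[idxs]
        counts[y] = get(counts, y, 0) + 1
    end
    best = first(keys(counts))
    for (y, c) in counts
        c > counts[best] && (best = y)
    end
    return best
end

knn_classify(pts, ys, pts[2], 3)
```

With `k = 1` the query point's own label is returned, since `pts[2]` is its own nearest neighbor at distance zero.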
The same general procedure can be followed to build other distance-based classifiers.
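As one such variant, here is a hedged sketch of a nearest-centroid classifier on flat data (the name `nearest_centroid`, the toy data, and the Euclidean mean are illustrative assumptions; on a manifold one would substitute the Riemannian mean and distance):

```julia
using LinearAlgebra, Statistics, Random

Random.seed!(1)

# Toy data: two well-separated classes of 3-dimensional points
pts = [randn(3) .+ 5 * (c - 1) for c in 1:2 for _ in 1:20]
ys = [c for c in 1:2 for _ in 1:20]

# Nearest-centroid classifier: assign the label of the closest class mean.
# (On a manifold, replace `mean` and `norm` with the Riemannian counterparts.)
function nearest_centroid(pts, ys, p)
    labels = unique(ys)
    centroids = Dict(c => mean(pts[ys .== c]) for c in labels)
    return argmin(c -> norm(centroids[c] - p), labels)
end

nearest_centroid(pts, ys, pts[1])
```

The classifier stores one centroid per class instead of all training points, trading kNN's flexibility for O(number of classes) prediction cost.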