NearestNeighbors 是 Julia 中一个高效的最近邻（KNN）搜索库，它提供了 BallTree、KDTree 等多种数据结构。
这里使用 KDTree 结构搜索欧氏距离最近的数据点，并绘制图表。这里仍然使用鸢尾花（iris）数据集。
代码示例
using RDatasets
using DataFrames
using CSV
using NearestNeighbors
using Colors
using PyPlot
using PyCall
import PyPlot:plot
import NearestNeighbors.HyperSphere
@pyimport matplotlib.patches as patch
# Load the iris data set shipped with RDatasets.
iris = dataset("datasets", "iris");
# Take columns 1:4 as the feature columns and transpose them so that every
# column of `features` is one sample (the layout used for the KD-tree search).
features = permutedims(Matrix(iris[:, 1:4]));
# Split the samples: column 1 becomes the query point, and the remaining
# columns become the search set. Note that `features` is rebound here, so the
# indices returned by the search refer to this reduced matrix.
point, features = features[:, 1], features[:, 2:end]
# How many nearest neighbours to look up.
k = 3
# Index the search set (one sample per column) with a KD-tree.
kdtree = KDTree(features)
# Query the k nearest neighbours of `point`. The positional `true` is
# NearestNeighbors' `sortres` flag, so results come back nearest-first.
idxs, dists = knn(kdtree, point, k, true)
# Column indices (into `features`) of the neighbours found, nearest first.
# `display` is used so the values are shown when the file is run as a script
# too, not only when the lines are pasted into the REPL.
display(idxs)
# Example output from one run:
# 3-element Array{Int64,1}:
# 17
# 4
# 39
# Feature vectors of those neighbours, one neighbour per column. Indexing with
# `idxs` keeps this correct for any query point or `k`.
display(features[:, idxs])
# Euclidean distances from `point` to each neighbour, in the same order as `idxs`.
display(dists)
# Example output from one run:
# 3-element Array{Float64,1}:
# 0.09999999999999998
# 0.1414213562373093
# 0.14142135623730964
# Neighbours returned by the k-NN query, one sample per row, in `idxs` order.
data = permutedims(features[:, idxs])
# Every remaining sample that is NOT a neighbour, one sample per row.
# The complement is derived from `idxs` rather than from hand-written column
# ranges. That keeps the two groups disjoint, so no neighbour is also drawn as
# an "other" point and no sample is listed twice.
other = permutedims(features[:, setdiff(axes(features, 2), idxs)])
# Build a palette of 4 mutually distinguishable colours seeded with black.
# Only entries 2-4 are used: 2 = other samples, 3 = neighbours, 4 = query point.
cols = distinguishable_colors(4, RGB(0,0,0))

# Convert a Colors.jl RGB value into the (r, g, b) float tuple matplotlib accepts.
rgbtuple(c) = (Float64(c.r), Float64(c.g), Float64(c.b))

# Create the figure with a single axes and an equal x/y scale.
# PyCall dot syntax is used; `obj[:name]` attribute access is deprecated.
cfig = figure()
ax = cfig.add_subplot(1, 1, 1)
ax.set_aspect("equal")
axis((0.0, 9.0, 0.0, 3.0))

# Plot feature 3 (x) against feature 4 (y). Each group is drawn with one
# `plot` call rather than one call per sample. The draw order (others,
# neighbours, query point) is kept so the query point ends up on top.
plot(other[:, 3], other[:, 4], "*"; color = rgbtuple(cols[2]))
plot(data[:, 3], data[:, 4], "*"; color = rgbtuple(cols[3]))
plot(point[3], point[4], "*"; color = rgbtuple(cols[4]))

title("iris")
cfig.savefig("iris.png")