NearestNeighbors 是 Julia 中一个效率较高的最近邻(KNN)搜索库,可用于 KNN 分类与统计,它提供了 BallTree、KDTree 等多种数据结构。
这里使用 BallTree 结构,并把各叶子节点及其包含的数据点绘制成图表。数据仍然使用鸢尾花(iris)数据集。
代码示例
using RDatasets
using DataFrames
using CSV
using NearestNeighbors
using Colors
using PyPlot
using PyCall
import PyPlot:plot
import NearestNeighbors.HyperSphere
@pyimport matplotlib.patches as patch
# Load the iris data set shipped with RDatasets and print it.
# (A local CSV copy could be read with `DataFrame(CSV.File(path))` instead.)
iris = dataset("datasets", "iris");
show(iris)

# Transpose the first four columns into a dense matrix with one observation
# per column, which is the layout handed to the tree constructor below.
features = permutedims(Matrix(iris[:, 1:4]));
tree = BallTree(features, Euclidean(); leafsize = 50)

# Skip past the internal (non-leaf) nodes to reach the first leaf index.
offset = tree.tree_data.n_internal_nodes + 1
nleafs = tree.tree_data.n_leafs
# Node indices covering all leaves.
index_range = offset:(offset + nleafs - 1)

# One distinguishable colour per leaf, with black as the seed colour.
cols = distinguishable_colors(length(index_range), RGB(0, 0, 0))

# Create the figure with a single subplot, equal axis scaling and fixed
# x/y limits (xmin, xmax, ymin, ymax).
cfig = figure()
ax = cfig[:add_subplot](1, 1, 1)
ax[:set_aspect]("equal")
axis((2.5, 9.0, 1.0, 5.0))
"""
    add_sphere(ax, hs::HyperSphere, col)

Draw `hs` on the matplotlib axes `ax` as an unfilled circle outlined in `col`
and return the value of `ax`'s `add_artist` call.

The circle is centred on the first two coordinates of `hs.center` and uses
`hs.r` as its radius, so a sphere with more than two dimensions is shown as its
projection onto the first two axes.
"""
function add_sphere(ax, hs::HyperSphere, col)
    # A matplotlib patch is two-dimensional: pass an explicit (x, y) centre
    # instead of relying on extra coordinates being ignored.
    center = (hs.center[1], hs.center[2])
    circle = patch.Circle(center, radius = hs.r, facecolor = "none", edgecolor = col)
    return ax[:add_artist](circle)
end
# Draw every leaf: its points as stars and its bounding sphere as a circle,
# both in the colour assigned to that leaf.
for (i, idx) in enumerate(index_range)
    col = cols[i]
    rgb = (col.r, col.g, col.b)
    # Points stored in this leaf node of the ball tree.
    leaf_range = NearestNeighbors.get_leaf_range(tree.tree_data, idx)
    leaf_points = tree.data[leaf_range]
    # Plot the first two coordinates of all points of the leaf in one call
    # rather than issuing one `plot` call per point.
    xs = [p[1] for p in leaf_points]
    ys = [p[2] for p in leaf_points]
    plot(xs, ys, "*", color = rgb)
    # Outline the leaf's bounding sphere.
    add_sphere(ax, tree.hyper_spheres[idx], rgb)
end
title("Leaf nodes with their corresponding points")
cfig[:savefig]("iris.png")
绘制的图表