之前写过一篇使用matlab建立Kd-tree并进行k-NN查询。之前写完之后没怎么测试就直接丢上来了,这些天有人后台私信问相关的问题后才发现搜索算法的实现好像有一些问题,而且整体代码比较杂乱,想着重新实现一遍,因此有了这篇。
在网上看了很多人的讲解,具体实现的方法千差万别,有的只在叶节点存放真实数据,在根节点中存放分割信息。有的所有节点都能存放数据信息。有的为了回溯时减少递归次数在节点中存放自身所代表的矩形框大小。在这里我也只实现一种我觉得方便的实现方法,就是每一个节点都存放一个数据点以及基于该数据点的分割信息,每次分割选择方差最大的维度进行分割,另外用绘制函数把2维下的kd-tree绘制出来,这样也更方便可视化了解算法思想。
用matlab来实现这种依赖于指针的树结构老实说并不方便,迫不得已使用数组的索引来代替指针来实现子节点到父节点回溯。后续将自己实现的kd-tree搜索和暴力搜索进行比较,在数据量较小或者数据维度较多时,kd-tree的搜索时间比暴力搜索还要长好多。而且严重依赖于构造出的kd-tree的质量,同样的数据量和数据维度,在几次不同的测试下kd-tree的最近邻搜索耗时可能相差数倍。只是在数据维度小于6维,并且数据量大小大于 2 14 2^{14} 214时,kd-tree的最近邻搜索才开始表现出优势。这很可能是与我的实现方法有关,在matlab上实现主要是了解算法的思想,以及可视化绘图方便,可能对于真实算法性能比较方面只能定性分析,没法准确定量。
kd-tree的最近邻搜索的实现主要参考了李小文的知乎专栏文章:KD Tree的原理及Python实现
运行结果如图:
具体代码如下:
clc
clear
close all
% 数据维度
% 建立kd-tree和最近邻搜索对于维度没限制,只是绘制kd-tree只有数据为2维才有效
DATA_DIM = 2;
% 数据量大小
DATA_SIZE = 20;
data = rand(DATA_SIZE, DATA_DIM);
data = unique(data, 'rows');
% 建立kd-tree前应当归一化数据,保证以方差作为选取划分维度方式的合理性
% 这里每个维度都是0到1的随机数因此不用再归一化
kd_tree = kd_tree_build(data);
if DATA_DIM == 2
figure;
hold on
Limit_X = [0, 1];
Limit_Y = [0, 1];
plot_kd_tree(gca, kd_tree, Limit_X, Limit_Y);
end
% 目标点
target = rand(1, DATA_DIM);
if DATA_DIM == 2
plot(target(1), target(2), '*');
end
% 比较kd-tree搜索和暴力搜索的用时
time_kd = 0;
time_simple = 0;
for i = 1: 10000
tic
res_kd = kd_tree_search(kd_tree, target);
elapsedTime = toc;
time_kd = time_kd + elapsedTime;
tic
res_simple = simple_search(data, target);
elapsedTime = toc;
time_simple = time_simple + elapsedTime;
assert(sqrt(sum((res_kd - res_simple).^2)) < 0.0000001, 'error:kd-tree搜索和暴力搜索结果不一致');
end
disp(['kd-tree搜索用时' num2str(time_kd) '秒']);
disp(['暴力搜索用时' num2str(time_simple) '秒']);
%% 根据data建立kd_tree
function kd_tree = kd_tree_build(data)
kd_tree = [];
kd_tree_build_recurse(data, []);
% 递归子函数
function index_cur = kd_tree_build_recurse(data_cur, index)
index_cur = [];
if size(data_cur, 1) == 0
return
end
index_cur = length(kd_tree) + 1;
kd_tree(index_cur).Dim = [];
kd_tree(index_cur).Cut = [];
kd_tree(index_cur).Value = [];
kd_tree(index_cur).Father = index;
kd_tree(index_cur).Left_son = [];
kd_tree(index_cur).Right_son = [];
% 选取方差最大的维度进行划分,分割值设置为该维度上的中位数
data_var = var(data_cur, 0, 1);
[~, choose_dim] = max(data_var);
data_cur = sortrows(data_cur, choose_dim);
kd_tree(index_cur).Dim = choose_dim;
left_data = data_cur(1: ceil(size(data_cur, 1)/2)-1, :);
right_data = data_cur(ceil(size(data_cur, 1)/2)+1: end, :);
kd_tree(index_cur).Value = data_cur(ceil(size(data_cur, 1)/2), :);
kd_tree(index_cur).Cut = kd_tree(index_cur).Value(choose_dim);
% 递归进行左右子树的构建
kd_tree(index_cur).Left_son = kd_tree_build_recurse(left_data, index_cur);
kd_tree(index_cur).Right_son = kd_tree_build_recurse(right_data, index_cur);
end
end
%% 在坐标区ax上绘制kd_tree
% 绘制函数只支持二维数据
function plot_kd_tree(ax, kd_tree, Limit_X, Limit_Y)
plot_kd_tree_recurse(1, Limit_X, Limit_Y);
% 递归子函数
% 绘制以kd_tree_cur为根节点的子树
function plot_kd_tree_recurse(kd_tree_cur, Limit_X_cur, Limit_Y_cur)
if isempty(kd_tree_cur)
return
end
if ~isempty(kd_tree(kd_tree_cur).Dim)
% 绘制分割线
% 递归进行子树的绘制
if kd_tree(kd_tree_cur).Dim == 1
plot(ax, [kd_tree(kd_tree_cur).Cut, kd_tree(kd_tree_cur).Cut], Limit_Y_cur, 'k-');
plot_kd_tree_recurse(kd_tree(kd_tree_cur).Left_son, [Limit_X_cur(1), kd_tree(kd_tree_cur).Cut], Limit_Y_cur);
plot_kd_tree_recurse(kd_tree(kd_tree_cur).Right_son, [kd_tree(kd_tree_cur).Cut, Limit_X_cur(2)], Limit_Y_cur);
end
if kd_tree(kd_tree_cur).Dim == 2
plot(ax, Limit_X_cur, [kd_tree(kd_tree_cur).Cut, kd_tree(kd_tree_cur).Cut], 'k-');
plot_kd_tree_recurse(kd_tree(kd_tree_cur).Left_son, Limit_X_cur, [Limit_Y_cur(1), kd_tree(kd_tree_cur).Cut]);
plot_kd_tree_recurse(kd_tree(kd_tree_cur).Right_son, Limit_X_cur, [kd_tree(kd_tree_cur).Cut, Limit_Y_cur(2)]);
end
end
% 绘制数据点
plot(ax, kd_tree(kd_tree_cur).Value(1), kd_tree(kd_tree_cur).Value(2), 'r.');
text(ax, kd_tree(kd_tree_cur).Value(1)+0.01, kd_tree(kd_tree_cur).Value(2), num2str(kd_tree_cur));
end
end
%% 查找最近邻的值
function res = kd_tree_search(kd_tree, target)
dist_best = inf;
node_best = sub_kd_tree_search(1);
queue = [1, node_best];
while ~isempty(queue)
root_cur = queue(1, 1);
node_cur = queue(1, 2);
queue(1, :) = [];
while 1
dist_cur = sqrt(sum((target - kd_tree(node_cur).Value).^2));
if dist_cur < dist_best
dist_best = dist_cur;
node_best = node_cur;
end
if node_cur ~= root_cur
node_brother = get_brother(node_cur);
if ~isempty(node_brother)
dist_temp = abs(target(kd_tree(kd_tree(node_cur).Father).Dim) - kd_tree(kd_tree(node_cur).Father).Cut);
if dist_temp < dist_best
queue(end+1, :) = [node_brother, sub_kd_tree_search(node_brother)];
end
end
node_cur = kd_tree(node_cur).Father;
else
break
end
end
end
res = kd_tree(node_best).Value;
% 以当前节点为根节点一直搜索到叶节点
function index = sub_kd_tree_search(index)
while ~isempty(kd_tree(index).Left_son) || ~isempty(kd_tree(index).Right_son)
if isempty(kd_tree(index).Left_son)
index = kd_tree(index).Right_son;
continue
end
if isempty(kd_tree(index).Right_son)
index = kd_tree(index).Left_son;
continue
end
if target(kd_tree(index).Dim) <= kd_tree(index).Cut
index = kd_tree(index).Left_son;
else
index = kd_tree(index).Right_son;
end
end
end
% 获取节点的兄弟节点
function brother_index = get_brother(index)
brother_index = [];
father_index = kd_tree(index).Father;
if isempty(father_index)
return
end
if kd_tree(father_index).Left_son == index
brother_index = kd_tree(father_index).Right_son;
else
brother_index = kd_tree(father_index).Left_son;
end
end
end
%% 暴力搜索
function res = simple_search(data, target)
dist = inf;
res = 0;
for i = 1: size(data, 1)
dist_temp = sqrt(sum((target - data(i,:)).^2));
if dist_temp < dist
dist = dist_temp;
res = i;
end
end
res = data(res, :);
return
end