使用matlab建立Kd-tree及最近邻查询

之前写过一篇使用matlab建立Kd-tree并进行k-NN查询。之前写完之后没怎么测试就直接丢上来了,这些天有人后台私信问相关的问题后才发现搜索算法的实现好像有一些问题,而且整体代码比较杂乱,想着重新实现一遍,因此有了这篇。

在网上看了很多人的讲解,具体实现的方法千差万别,有的只在叶节点存放真实数据,在根节点中存放分割信息。有的所有节点都能存放数据信息。有的为了回溯时减少递归次数在节点中存放自身所代表的矩形框大小。在这里我也只实现一种我觉得方便的实现方法,就是每一个节点都存放一个数据点以及基于该数据点的分割信息,每次分割选择方差最大的维度进行分割,另外用绘制函数把2维下的kd-tree绘制出来,这样也更方便可视化了解算法思想。

在这里插入图片描述用matlab来实现这种依赖于指针的树结构老实说并不方便,迫不得已使用数组的索引来代替指针来实现子节点到父节点回溯。后续将自己实现的kd-tree搜索和暴力搜索进行比较,在数据量较小或者数据维度较多时,kd-tree的搜索时间比暴力搜索还要长好多。而且严重依赖于构造出的kd-tree的质量,同样的数据量和数据维度,在几次不同的测试下kd-tree的最近邻搜索耗时可能相差数倍。只是在数据维度小于6维,并且数据量大小大于 2 14 2^{14} 214时,kd-tree的最近邻搜索才开始表现出优势。这很可能是与我的实现方法有关,在matlab上实现主要是了解算法的思想,以及可视化绘图方便,可能对于真实算法性能比较方面只能定性分析,没法准确定量。

kd-tree的最近邻搜索的实现主要参考了李小文的知乎专栏文章:KD Tree的原理及Python实现

运行结果如图:
在这里插入图片描述

具体代码如下:

clc
clear
close all

% 数据维度
% 建立kd-tree和最近邻搜索对于维度没限制,只是绘制kd-tree只有数据为2维才有效
DATA_DIM = 2;
% 数据量大小
DATA_SIZE = 20;

data = rand(DATA_SIZE, DATA_DIM);
data = unique(data, 'rows');

% 建立kd-tree前应当归一化数据,保证以方差作为选取划分维度方式的合理性
% 这里每个维度都是01的随机数因此不用再归一化
kd_tree = kd_tree_build(data);

if DATA_DIM == 2
    figure;
    hold on
    Limit_X = [0, 1];
    Limit_Y = [0, 1];
    plot_kd_tree(gca, kd_tree, Limit_X, Limit_Y);
end

% 目标点
target = rand(1, DATA_DIM);
if DATA_DIM == 2
    plot(target(1), target(2), '*');
end

% 比较kd-tree搜索和暴力搜索的用时
time_kd = 0;
time_simple = 0;
for i = 1: 10000
    tic
    res_kd = kd_tree_search(kd_tree, target);
    elapsedTime = toc;
    time_kd = time_kd + elapsedTime;
    
    tic
    res_simple = simple_search(data, target);
    elapsedTime = toc;
    time_simple = time_simple + elapsedTime;
    
    assert(sqrt(sum((res_kd - res_simple).^2)) < 0.0000001, 'error:kd-tree搜索和暴力搜索结果不一致');
end

disp(['kd-tree搜索用时' num2str(time_kd) '秒']);
disp(['暴力搜索用时' num2str(time_simple) '秒']);

%% 根据data建立kd_tree
function kd_tree = kd_tree_build(data)
kd_tree = [];
kd_tree_build_recurse(data, []);
    % 递归子函数
    function index_cur = kd_tree_build_recurse(data_cur, index)
        index_cur = [];
        if size(data_cur, 1) == 0
            return
        end
        
        index_cur = length(kd_tree) + 1;
        kd_tree(index_cur).Dim = [];
        kd_tree(index_cur).Cut = [];
        kd_tree(index_cur).Value = [];
        kd_tree(index_cur).Father = index;
        kd_tree(index_cur).Left_son = [];
        kd_tree(index_cur).Right_son = [];
        
        % 选取方差最大的维度进行划分,分割值设置为该维度上的中位数
        data_var = var(data_cur, 0, 1);
        [~, choose_dim] = max(data_var);
        data_cur = sortrows(data_cur, choose_dim);
        kd_tree(index_cur).Dim = choose_dim;
        left_data = data_cur(1: ceil(size(data_cur, 1)/2)-1, :);
        right_data = data_cur(ceil(size(data_cur, 1)/2)+1: end, :);
        kd_tree(index_cur).Value = data_cur(ceil(size(data_cur, 1)/2), :);
        kd_tree(index_cur).Cut = kd_tree(index_cur).Value(choose_dim);
        % 递归进行左右子树的构建
        kd_tree(index_cur).Left_son = kd_tree_build_recurse(left_data, index_cur);
        kd_tree(index_cur).Right_son = kd_tree_build_recurse(right_data, index_cur);
    end
end

%% 在坐标区ax上绘制kd_tree
% 绘制函数只支持二维数据
function plot_kd_tree(ax, kd_tree, Limit_X, Limit_Y)
plot_kd_tree_recurse(1, Limit_X, Limit_Y);

    % 递归子函数
    % 绘制以kd_tree_cur为根节点的子树
    function plot_kd_tree_recurse(kd_tree_cur, Limit_X_cur, Limit_Y_cur)
        if isempty(kd_tree_cur)
            return
        end
        
        if ~isempty(kd_tree(kd_tree_cur).Dim)
            % 绘制分割线
            % 递归进行子树的绘制
            if kd_tree(kd_tree_cur).Dim == 1
                plot(ax, [kd_tree(kd_tree_cur).Cut, kd_tree(kd_tree_cur).Cut], Limit_Y_cur, 'k-');
                plot_kd_tree_recurse(kd_tree(kd_tree_cur).Left_son, [Limit_X_cur(1), kd_tree(kd_tree_cur).Cut], Limit_Y_cur);
                plot_kd_tree_recurse(kd_tree(kd_tree_cur).Right_son, [kd_tree(kd_tree_cur).Cut, Limit_X_cur(2)], Limit_Y_cur);
            end
            if kd_tree(kd_tree_cur).Dim == 2
                plot(ax, Limit_X_cur, [kd_tree(kd_tree_cur).Cut, kd_tree(kd_tree_cur).Cut], 'k-');
                plot_kd_tree_recurse(kd_tree(kd_tree_cur).Left_son, Limit_X_cur, [Limit_Y_cur(1), kd_tree(kd_tree_cur).Cut]);
                plot_kd_tree_recurse(kd_tree(kd_tree_cur).Right_son, Limit_X_cur, [kd_tree(kd_tree_cur).Cut, Limit_Y_cur(2)]);
            end
        end
        
        % 绘制数据点
        plot(ax, kd_tree(kd_tree_cur).Value(1), kd_tree(kd_tree_cur).Value(2), 'r.');
        text(ax, kd_tree(kd_tree_cur).Value(1)+0.01, kd_tree(kd_tree_cur).Value(2), num2str(kd_tree_cur));
    end
end

%% 查找最近邻的值
function res = kd_tree_search(kd_tree, target)
    dist_best = inf;
    node_best = sub_kd_tree_search(1);
    queue = [1, node_best];
    while ~isempty(queue)
        root_cur = queue(1, 1);
        node_cur = queue(1, 2);
        queue(1, :) = [];
        while 1
            dist_cur = sqrt(sum((target - kd_tree(node_cur).Value).^2));
            if dist_cur < dist_best
                dist_best = dist_cur;
                node_best = node_cur;
            end
            if node_cur ~= root_cur
                node_brother = get_brother(node_cur);
                if ~isempty(node_brother)
                    dist_temp = abs(target(kd_tree(kd_tree(node_cur).Father).Dim) - kd_tree(kd_tree(node_cur).Father).Cut);
                    if dist_temp < dist_best
                        queue(end+1, :) = [node_brother, sub_kd_tree_search(node_brother)];
                    end
                end
                node_cur = kd_tree(node_cur).Father;
            else
                break
            end
        end
    end
    res = kd_tree(node_best).Value;
    
    % 以当前节点为根节点一直搜索到叶节点
    function index = sub_kd_tree_search(index)
        while ~isempty(kd_tree(index).Left_son) || ~isempty(kd_tree(index).Right_son)
            if isempty(kd_tree(index).Left_son)
                index = kd_tree(index).Right_son;
                continue
            end
            if isempty(kd_tree(index).Right_son)
                index = kd_tree(index).Left_son;
                continue
            end
            if target(kd_tree(index).Dim) <= kd_tree(index).Cut
                index = kd_tree(index).Left_son;
            else
                index = kd_tree(index).Right_son;
            end
        end
    end

    % 获取节点的兄弟节点
    function brother_index = get_brother(index)
        brother_index = [];
        father_index = kd_tree(index).Father;
        if isempty(father_index)
            return
        end
        if kd_tree(father_index).Left_son == index
            brother_index = kd_tree(father_index).Right_son;
        else
            brother_index = kd_tree(father_index).Left_son;
        end
    end
end

%% 暴力搜索
function res = simple_search(data, target)
dist = inf;
res = 0;
for i = 1: size(data, 1)
    dist_temp = sqrt(sum((target - data(i,:)).^2));
    if dist_temp < dist
        dist = dist_temp;
        res = i;
    end
end
res = data(res, :);
return 
end
  • 6
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
以下是 MATLAB 中实现 KD 的示例代码: ```matlab function [idx, dist] = kdtree_search(X, Y, k) % X: 点集,每行表示一个点,共 n 行 % Y: 待查询点,1 行 m 列 % k: 近邻数 % idx: 最近的 k 个点的下标,1 行 k 列 % dist: 最近的 k 个点到 Y 的距离,1 行 k 列 n = size(X, 1); % 构建 KD root = build_kdtree(X, 1:n, 1); % 初始化最近的 k 个点和它们的距离 idx = zeros(1, k); dist = inf(1, k); % 搜索 KD search_kdtree(root, X, Y, k, idx, dist); end function root = build_kdtree(X, idxs, d) % X: 点集,每行表示一个点,共 n 行 % idxs: 点集 X 中点的下标,1 行 m 列 % d: 当前切分的维度 % root: KD 根节点 m = length(idxs); if m == 0 root = []; else % 找到中位数的下标 [~, i] = sort(X(idxs, d)); mid = ceil(m / 2); mid_idx = idxs(i(mid)); % 构建 KD 的左右子 root = struct('idx', mid_idx, 'left', [], 'right', []); root.left = build_kdtree(X, idxs(i(1:mid-1)), mod(d, size(X, 2)) + 1); root.right = build_kdtree(X, idxs(i(mid+1:end)), mod(d, size(X, 2)) + 1); end end function search_kdtree(root, X, Y, k, idx, dist) % root: KD 节点 % X: 点集,每行表示一个点,共 n 行 % Y: 待查询点,1 行 m 列 % k: 近邻数 % idx: 最近的 k 个点的下标,1 行 k 列 % dist: 最近的 k 个点到 Y 的距离,1 行 k 列 if isempty(root) return end % 计算当前节点到 Y 的距离 d = norm(X(root.idx, :) - Y); % 如果当前节点更近,则更新最近的 k 个点 if d < dist(k) idx(k) = root.idx; dist(k) = d; [~, i] = sort(dist); idx = idx(i); dist = dist(i); end % 计算待查询点在当前维度上的坐标 y = Y(mod(root.left.d - 1, size(Y, 2)) + 1); % 先查询离待查询点更近的子 if y < X(root.idx, root.left.d) search_kdtree(root.left, X, Y, k, idx, dist); search_kdtree(root.right, X, Y, k, idx, dist); else search_kdtree(root.right, X, Y, k, idx, dist); search_kdtree(root.left, X, Y, k, idx, dist); end end ``` 该代码实现了 KD 的构建和查询,可以通过调用 `kdtree_search` 函数来进行查询。其中,`X` 是点集,每行表示一个点;`Y` 是待查询点,1 行表示一个点;`k` 是近邻数,即要查询最近的 k 个点。函数返回值为最近的 k 个点的下标和它们到待查询点的距离。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值