《统计学习方法》C++实现kd tree

参考:https://www.cnblogs.com/90zeng/p/kdtree.html
作者写的非常好,我只是改动成了我习惯的格式,稍许小改动,感谢作者

#include <iostream>
#include <vector>
#include "kd_tree.hpp"
using namespace std;


int main()
{
    int data[6][2] = {{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};

    vector<vector<double> > train(6, vector<double>(2, 0));
    for (unsigned i = 0; i < 6; ++i)
        for (unsigned j = 0; j < 2; ++j)
            train[i][j] = data[i][j];

    auto* kdTree = new kd_tree;
    build_kd_tree(kdTree, train, 0);

    print_kd_tree(kdTree, 0);

    vector<double> goal;
    goal.push_back(3);
    goal.push_back(4.5);
    vector<double> nearest_neighbor = search_nearest_neighbor(goal, kdTree);
    auto beg = nearest_neighbor.begin();
    cout << "The nearest neighbor is: ";
    while(beg != nearest_neighbor.end()) cout << *beg++ << ",";
    cout << endl;
    return 0;
}
//
// Created by gu on 12/16/2020.
//

#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
using namespace std;

class kd_tree{
public:
    vector<double> root;
    kd_tree* parent;
    kd_tree* left_child;
    kd_tree* right_child;

    //默认构造函数
    kd_tree() : parent(nullptr), left_child(nullptr), right_child(nullptr){}

    //判断kd树是否为空
    bool is_empty() const
    {
        return root.empty();
    }

    //判断kd树是否只是一个叶子结点
    bool is_leaf() const
    {
        return (!root.empty()) && right_child == nullptr && left_child == nullptr;
    }

    //判断是否是树的根结点
    bool is_root() const
    {
        return (!is_empty()) && parent == nullptr;
    }

    //判断该子kd树的根结点是否是其父kd树的左结点
    bool is_left() const
    {
        return parent->left_child->root == root;
    }

    //判断该子kd树的根结点是否是其父kd树的右结点
    bool is_right() const
    {
        return parent->right_child->root == root;
    }
};



/*
 * 转置一个矩阵,返回
 * * */
template<typename T>
vector<vector<T> > transpose(vector<vector<T> > Matrix)
{
    unsigned row = Matrix.size();
    unsigned col = Matrix[0].size();

    vector<vector<T> > Trans(col,vector<T>(row,0));
    for (unsigned i = 0; i < col; i++)
    {
        for (unsigned j = 0; j < row; ++j)
        {
            Trans[i][j] = Matrix[j][i];
        }
    }
    return Trans;
}

/*
 * 查找数组的中位数
 * */
template <typename T>
T find_middle_value(vector<T> vec)
{
    sort(vec.begin(),vec.end());
    auto pos = vec.size() / 2;
    return vec[pos];
}

//构建kd树
void build_kd_tree(kd_tree* tree, vector<vector<double> > data, unsigned depth)
{
    //样本的数量
    unsigned samples_num = data.size();

    //终止条件
    if (samples_num == 0)
    {
        return;
    }
    if (samples_num == 1)
    {
        tree->root = data[0];
        return;
    }

    //样本的维度
    unsigned dimension = data[0].size();
    vector<vector<double> > trans_data = transpose(data);
    
    //选择切分属性
    unsigned split_attribute = depth % dimension;
    vector<double> split_attribute_values = trans_data[split_attribute];
    
    //选择切分值
    double splitValue = find_middle_value(split_attribute_values);
    //cout << "splitValue" << splitValue  << endl;

    // 根据选定的切分属性和切分值,将数据集分为两个子集
    vector<vector<double>> subset1;
    vector<vector<double>> subset2;
    for (unsigned i = 0; i < samples_num; i++)
    {
        if (split_attribute_values[i] == splitValue && tree->root.empty())
            tree->root = data[i];
        else
        {
            if (split_attribute_values[i] < splitValue)
                subset1.push_back(data[i]);
            else
                subset2.push_back(data[i]);
        }
    }

    //子集递归调用buildkd_tree函数

    tree->left_child = new kd_tree;
    tree->left_child->parent = tree;

    tree->right_child = new kd_tree;
    tree->right_child->parent = tree;

    build_kd_tree(tree->left_child, subset1, depth + 1);
    build_kd_tree(tree->right_child, subset2, depth + 1);
}

//逐层打印kd树
void print_kd_tree(kd_tree *tree, unsigned depth)
{
    for (unsigned i = 0; i < depth; i++)
        cout << "\t";

    for (double j : tree->root)
        cout << j << ",";
    cout << endl;
    if (tree->left_child == nullptr && tree->right_child == nullptr )//叶子节点
        return;
    else //非叶子节点
    {
        if (tree->left_child != nullptr)
        {
            for (unsigned i = 0; i < depth + 1; i++)
                cout << "\t";
            cout << " left:";
            print_kd_tree(tree->left_child, depth + 1);
        }

        cout << endl;
        if (tree->right_child != nullptr)
        {
            for (unsigned i = 0; i < depth + 1; i++)
                cout << "\t";
            cout << "right:";
            print_kd_tree(tree->right_child, depth + 1);
        }
        cout << endl;
    }
}


//计算空间中两个点的距离
double measure_distance(vector<double> point1, vector<double> point2, unsigned method)
{
    if (point1.size() != point2.size())
    {
        cerr << "Dimensions don't match!!" ;
        return -1;
    }
    switch (method)
    {
        case 0://欧氏距离
        {
            double res = 0;
            for (vector<double>::size_type i = 0; i < point1.size(); i++)
            {
                res += pow((point1[i] - point2[i]), 2);
            }
            return sqrt(res);
        }
        case 1://曼哈顿距离
        {
            double res = 0;
            for (vector<double>::size_type i = 0; i < point1.size(); i++)
            {
                res += abs(point1[i] - point2[i]);
            }
            return res;
        }
        default:
        {
            cerr << "Invalid method!!" << endl;
            return -1;
        }
    }
}

//在kd tree中搜索目标点的最近邻
//输入:目标点, 已构造的kd树
//输出:目标点的最近邻
vector<double> search_nearest_neighbor(vector<double> goal, kd_tree *tree)
{
    /*第一步:在kd树中找出包含目标点的叶子结点:从根结点出发,
    递归的向下访问kd树,若目标点的当前维的坐标小于切分点的
    坐标,则移动到左子结点,否则移动到右子结点,直到子结点为
    叶结点为止,以此叶子结点为“当前最近点”
    */
    unsigned k = tree->root.size();                 //计算出数据的维数
    unsigned d = 0;                                 //维度初始化为0,即从第1维开始
    kd_tree* current_tree = tree;
    vector<double> current_nearest = current_tree->root;
    while(!current_tree->is_leaf())
    {
        unsigned index = d % k;                     //计算当前维度,深度
        if (current_tree->right_child->is_empty() || goal[index] < current_nearest[index])
        {
            current_tree = current_tree->left_child;
        }
        else
        {
            current_tree = current_tree->right_child;
        }
        d++;
    }
    current_nearest = current_tree->root;

    /*第二步:递归地向上回退, 在每个结点进行如下操作:
    (a)如果该结点保存的实例比当前最近点距离目标点更近,则以该例点为“当前最近点”
    (b)当前最近点一定存在于某结点一个子结点对应的区域,检查该子结点的父结点的另
    一子结点对应区域是否有更近的点(即检查另一子结点对应的区域是否与以目标点为球
    心、以目标点与“当前最近点”间的距离为半径的球体相交);如果相交,可能在另一
    个子结点对应的区域内存在距目标点更近的点,移动到另一个子结点,接着递归进行最
    近邻搜索;如果不相交,向上回退*/

    //当前最近邻与目标点的距离
    double current_distance = measure_distance(goal, current_nearest, 0);

    //如果当前子kd树的根结点是其父结点的左孩子,则搜索其父结点的右孩子结点所代表的区域,反之亦反
    kd_tree* search_district;
    if (current_tree->is_left())
    {
        if (current_tree->parent->right_child == nullptr)
            search_district = current_tree;
        else
            search_district = current_tree->parent->right_child;
    }
    else
    {
        if (current_tree->parent->left_child == nullptr)
            search_district = current_tree;
        else
            search_district = current_tree->parent->left_child;
    }

    //如果搜索区域对应的子kd树的根结点不是整个kd树的根结点,继续回退搜索
    while (search_district->parent != nullptr)
    {
        //搜索区域与目标点的最近距离
        double district_distance = abs(goal[(d + 1) % k] - search_district->parent->root[(d + 1) % k]);

        //如果“搜索区域与目标点的最近距离”比“当前最近邻与目标点的距离”短,表明搜索区域内可能存在距离目标点更近的点
        if (district_distance < current_distance )//&& !search_district->is_empty()
        {
            double parent_distance = measure_distance(goal, search_district->parent->root, 0);

            if (parent_distance < current_distance)
            {
                current_distance = parent_distance;
                current_tree = search_district->parent;
                current_nearest = current_tree->root;
            }
            if (!search_district->is_empty())
            {
                double root_distance = measure_distance(goal, search_district->root, 0);
                if (root_distance < current_distance)
                {
                    current_distance = root_distance;
                    current_tree = search_district;
                    current_nearest = current_tree->root;
                }
            }
            if (search_district->left_child != nullptr)
            {
                double left_distance = measure_distance(goal, search_district->left_child->root, 0);
                if (left_distance < current_distance)
                {
                    current_distance = left_distance;
                    current_tree = search_district;
                    current_nearest = current_tree->root;
                }
            }
            if (search_district->right_child != nullptr)
            {
                double right_distance = measure_distance(goal, search_district->right_child->root, 0);
                if (right_distance < current_distance)
                {
                    current_distance = right_distance;
                    current_tree = search_district;
                    current_nearest = current_tree->root;
                }
            }
        }//end if

        if (search_district->parent->parent != nullptr)
        {
            search_district = search_district->parent->is_left()?
                             search_district->parent->parent->right_child:
                             search_district->parent->parent->left_child;
        }
        else
        {
            search_district = search_district->parent;
        }
        d++;
    }//end while
    return current_nearest;
}

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值