KDTree的C++实现

KDTree原理:

请参考
1. k-d tree算法的研究
2. Python手撸机器学习系列(十一):KNN之kd树实现
完整代码: https://github.com/nnzzll/NaiveKDTree

C++实现

点的结构

template <typename T>
struct Point3D
{
    T x, y, z;
    // Index of this point. KDTree looks points up as m_data[index], so it should equal
    // the point's position in the input vector (-1 = not set).
    int index;
    Point3D() : x(0), y(0), z(0), index(-1) {}
    Point3D(T a, T b, T c) : x(a), y(b), z(c), index(-1) {}
    Point3D(T a, T b, T c, int idx) : x(a), y(b), z(c), index(idx) {}

    /// Mutable component access: 0 -> x, 1 -> y, any other value -> z.
    T &operator[](int i) { return i == 0 ? x : (i == 1 ? y : z); }
    /// Read-only component access with the same mapping, usable on const points.
    const T &operator[](int i) const { return i == 0 ? x : (i == 1 ? y : z); }
};

template <typename T>
struct Point2D
{
    T x, y;
    // Index of this point. KDTree looks points up as m_data[index], so it should equal
    // the point's position in the input vector (-1 = not set).
    int index;
    Point2D() : x(0), y(0), index(-1) {}
    Point2D(T a, T b) : x(a), y(b), index(-1) {}
    Point2D(T a, T b, int idx) : x(a), y(b), index(idx) {}

    /// Mutable component access: 0 -> x, any other value -> y.
    T &operator[](int i) { return i == 0 ? x : y; }
    /// Read-only component access with the same mapping, usable on const points.
    const T &operator[](int i) const { return i == 0 ? x : y; }
};

KDTree结点的结构

/// One node of the k-d tree. It stores only the index of its point (the coordinates
/// are looked up elsewhere by that index) plus the split axis and the child links.
struct KDNode
{
	int index;     // index of the point stored at this node
	int axis;      // dimension this node splits on (-1 for nodes that cannot be split)
	KDNode *left;  // subtree holding the points on the low side of the split
	KDNode *right; // subtree holding the points on the high side of the split

	KDNode(int index, int axis, KDNode *left = nullptr, KDNode *right = nullptr)
		: index(index), axis(axis), left(left), right(right)
	{
	}
};

KDTree的结构

template <class T>
class KDTree
{
private:
	int ndim;                       // number of dimensions per point, as passed to the constructor
	KDNode *root = nullptr;         // root of the tree; nodes are allocated with new in build()
	KDNode *build(std::vector<T> &);
	std::set<int> visited;          // indices of nodes already visited by the current query (for backtracking)
	std::stack<KDNode *> queueNode; // current search path, root at the bottom
	std::vector<T> m_data;          // private copy of the input points, addressed by point index

	void release(KDNode *);
	void printNode(KDNode *);
	int chooseAxis(std::vector<T> &);
	void dfs(KDNode *, T);
	// Point-to-point distance
	inline double distanceT(KDNode *, T);
	inline double distanceT(int, T);
	// Point-to-hyperplane distance
	inline double distanceP(KDNode *, T);
	// Checks whether the parent's splitting hyperplane intersects the search hypersphere
	inline bool checkParent(KDNode *, T, double);

public:
	KDTree(std::vector<T> &, int);
	~KDTree();
	// The tree holds raw owning node pointers (allocated in build(), presumably freed by
	// ~KDTree via release()). A member-wise copy would share those nodes and release them
	// twice, so copying (and therefore moving) is disabled.
	KDTree(const KDTree &) = delete;
	KDTree &operator=(const KDTree &) = delete;
	void Print();
	int findNearestPoint(T);
};

KDTree的构造函数

template <class T>
KDTree<T>::KDTree(std::vector<T> &data, int dim) : ndim(dim), root(nullptr), m_data(data)
{
	// Builds the tree over a private copy of `data`; the caller's vector is not modified.
	//
	// build(), dfs() and findNearestPoint() look points up as m_data[p.index], so every
	// point's index must equal its position in m_data. Enforce that on the copy: points
	// created with the default or 3-argument constructors carry index -1, which would
	// otherwise make those lookups read out of bounds. As a consequence,
	// findNearestPoint() returns the position of the nearest point within `data`.
	for (std::size_t i = 0; i < m_data.size(); ++i)
		m_data[i].index = static_cast<int>(i);
	root = build(m_data);
}

template <class T>
KDNode *KDTree<T>::build(std::vector<T> &data)
{
	// Recursively builds a subtree over `data` and returns its root (nullptr when `data`
	// is empty). `data` itself is not modified; a working copy is partitioned instead.
	if (data.empty())
		return nullptr;

	std::vector<T> temp = data;
	const std::ptrdiff_t mid = static_cast<std::ptrdiff_t>(temp.size() / 2); // split position
	// The split dimension is picked by chooseAxis (based on per-dimension variance). A single
	// point cannot be split, so leaves get axis -1.
	const int axis = temp.size() > 1 ? chooseAxis(temp) : -1;
	if (temp.size() > 1)
	{
		// Only the median has to land in its sorted position. nth_element guarantees that
		// nothing before it is greater and nothing after it is smaller along `axis`, which
		// is exactly the left/right split the tree needs, in O(n) average time instead of a
		// full O(n log n) sort at every level.
		std::nth_element(temp.begin(), temp.begin() + mid, temp.end(),
						 [axis](T a, T b) { return a[axis] < b[axis]; });
	}

	std::vector<T> leftData(temp.begin(), temp.begin() + mid);
	std::vector<T> rightData(temp.begin() + mid + 1, temp.end());

	KDNode *leftNode = build(leftData);
	KDNode *rightNode = build(rightData);
	return new KDNode(temp[mid].index, axis, leftNode, rightNode);
}

最近邻搜索

搜索流程参考文献[1](k-d tree算法的研究)。

template <class T>
int KDTree<T>::findNearestPoint(T pt)
{
	// Returns the index of the stored point nearest to `pt`, or -1 if the tree is empty.
	//
	// Reset all per-query search state. `visited` must be cleared as well: otherwise every
	// query after the first sees the root as already visited and returns -1.
	while (!queueNode.empty())
		queueNode.pop();
	visited.clear();

	double min_dist = DBL_MAX;
	int resNodeIdx = -1;
	dfs(root, pt);
	while (!queueNode.empty())
	{
		KDNode *curNode = queueNode.top();
		queueNode.pop();
		const double dist = distanceT(curNode, pt);
		if (dist < min_dist)
		{
			min_dist = dist;
			resNodeIdx = curNode->index;
		}

		if (queueNode.empty())
			continue;

		// dfs() only ever pushes a node followed by one of its children, so the entry now on
		// top of the stack is curNode's parent in the tree.
		KDNode *parentNode = queueNode.top();
		// If the current best ball around `pt` crosses the parent's splitting plane, the
		// parent's other subtree may still contain a closer point.
		if (checkParent(parentNode, pt, min_dist))
		{
			// Pick the sibling by node identity rather than by comparing coordinates: points
			// equal to the parent along its axis may sit in either subtree, so a coordinate
			// test can choose curNode's own (already visited) side and skip the sibling.
			KDNode *sibling = (curNode == parentNode->left) ? parentNode->right : parentNode->left;
			dfs(sibling, pt);
		}
	}
	return resNodeIdx;
}

template <class T>
void KDTree<T>::dfs(KDNode *node, T pt)
{
	// Descends from `node` towards the region containing `pt`. Every newly reached node is
	// pushed onto the search path (queueNode) and recorded in `visited`. Nodes already
	// visited in this query are skipped, so backtracking never re-enters a subtree.
	if (node == nullptr || visited.count(node->index) != 0)
		return;
	queueNode.push(node);
	visited.insert(node->index);

	// A leaf has no children and its axis is -1 (see build); stop before using that axis
	// as a coordinate index.
	if (node->left == nullptr && node->right == nullptr)
		return;

	const int axis = node->axis;
	if (node->left && pt[axis] <= m_data[node->index][axis])
		dfs(node->left, pt);
	else if (node->right && pt[axis] >= m_data[node->index][axis])
		dfs(node->right, pt);
	// The preferred side is empty but the other side is not (e.g. a node built from two
	// points has only a left child). The axis value alone cannot place `pt` in a subtree
	// here, so keep descending into the existing child; skipping it could miss a neighbour.
	else if ((node->left == nullptr) != (node->right == nullptr))
		dfs(node->left ? node->left : node->right, pt);
}

测试

与VTK官方样例ClosestNPoints的结果进行对比验证。
点云个数在1000以内时,性能与VTK的KdTree大致相当(两边的计时都包含建树、一次最近邻查询和结果输出)。

int main()
{
    int N = 500;
    // Create some random points
    vtkNew<vtkPointSource> pointSource;
    pointSource->SetNumberOfPoints(N);
    pointSource->Update();

    std::vector<Point3D<double>> datasets;
    vtkPoints *randPts = pointSource->GetOutput()->GetPoints();
    for (vtkIdType i = 0; i < N; i++)
    {
        double pts[3];
        randPts->GetPoint(i, pts);
        datasets.push_back(Point3D<double>(pts[0], pts[1], pts[2], i));
        // std::cout << pts[0] << "," << pts[1] << "," << pts[2] << std::endl;
    }

    auto t1 = std::chrono::duration_cast<std::chrono::milliseconds>(
                  std::chrono::system_clock::now().time_since_epoch())
                  .count();
    
    // Create the tree
    vtkNew<vtkKdTreePointLocator> pointTree;
    pointTree->SetDataSet(pointSource->GetOutput());
    pointTree->BuildLocator();

    // Find the k closest points to (0,0,0)
    unsigned int k = 1;
    vtkNew<vtkPointSource> testSource;
    testSource->SetNumberOfPoints(1);
    testSource->Update();
    double testPoint[3];
    testSource->GetOutput()->GetPoints()->GetPoint(0, testPoint);
    vtkNew<vtkIdList> result;
    std::cout << "Test Point: " << testPoint[0] << "," << testPoint[1] << "," << testPoint[2] << std::endl;

    pointTree->FindClosestNPoints(k, testPoint, result);

    for (vtkIdType i = 0; i < k; i++)
    {
        vtkIdType point_ind = result->GetId(i);
        double p[3];
        pointSource->GetOutput()->GetPoint(point_ind, p);
        std::cout << "Closest point " << i << ": Point " << point_ind << ": ("
                  << p[0] << ", " << p[1] << ", " << p[2] << ")" << std::endl;
    }
    auto t2 = std::chrono::duration_cast<std::chrono::milliseconds>(
                  std::chrono::system_clock::now().time_since_epoch())
                  .count();

    // Should return:
    // Closest point 0: Point 2: (-0.136162, -0.0276359, 0.0369441)

    // std::vector<Point2D<double>> datasets = {Point2D<double>(7, 2, 0),
    //                                       Point2D<double>(5, 4, 1),
    //                                       Point2D<double>(9, 6, 2),
    //                                       Point2D<double>(2, 3, 3),
    //                                       Point2D<double>(4, 7, 4),
    //                                       Point2D<double>(8, 1, 5)};
    KDTree<Point3D<double>> tree(datasets, 3);
    // tree.Print();
    std::cout << tree.findNearestPoint(Point3D<double>(testPoint[0], testPoint[1], testPoint[2])) << std::endl;
    auto t3 = std::chrono::duration_cast<std::chrono::milliseconds>(
                  std::chrono::system_clock::now().time_since_epoch())
                  .count();
    std::cout << "VTK Time:" << t2 - t1 << " ms" << std::endl;
    std::cout << "MY Time:" << t3 - t2 << " ms" << std::endl;
    return EXIT_SUCCESS;
}
Test Point: 0.117163,-0.205549,0.352397
Closest point 0: Point 474: (0.12327, -0.22358, 0.322906)
474
VTK Time:4 ms
MY Time:2 ms
  • 2
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 9
    评论
以下是一个简单的C++实现k-d tree的例子:

```cpp
#include <cstddef>
#include <iostream>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

// A k-d tree node: one stored point plus owning links to its two subtrees.
struct Node {
    std::vector<int> point;
    std::unique_ptr<Node> left;
    std::unique_ptr<Node> right;
    explicit Node(std::vector<int> p) : point(std::move(p)) {}
};

class KDTree {
public:
    // Inserts `point`; all points must share the same, non-zero dimension.
    void insert(const std::vector<int>& point) { insert(root, point, 0); }

    // Returns the stored node closest to `target` (squared Euclidean distance),
    // or nullptr when the tree is empty.
    const Node* nearest(const std::vector<int>& target) const {
        const Node* best = nullptr;
        long long bestDist = std::numeric_limits<long long>::max();
        nearest(root.get(), target, 0, best, bestDist);
        return best;
    }

    // Prints the stored point closest to `point`.
    void search(const std::vector<int>& point) const {
        const Node* best = nearest(point);
        if (best == nullptr) {
            std::cout << "Tree is empty" << std::endl;
            return;
        }
        std::cout << "Nearest point: ";
        for (int v : best->point) std::cout << v << " ";
        std::cout << std::endl;
    }

private:
    std::unique_ptr<Node> root;

    static void insert(std::unique_ptr<Node>& node, const std::vector<int>& point, std::size_t depth) {
        if (!node) {
            node = std::make_unique<Node>(point);
            return;
        }
        const std::size_t cd = depth % point.size();  // split dimension cycles with depth
        // Smaller values go left; equal or larger values go right.
        insert(point[cd] < node->point[cd] ? node->left : node->right, point, depth + 1);
    }

    // Squared Euclidean distance, accumulated in long long to avoid int overflow.
    static long long distance(const std::vector<int>& a, const std::vector<int>& b) {
        long long d = 0;
        for (std::size_t i = 0; i < a.size(); ++i) {
            const long long diff = static_cast<long long>(a[i]) - b[i];
            d += diff * diff;
        }
        return d;
    }

    static void nearest(const Node* node, const std::vector<int>& target, std::size_t depth,
                        const Node*& best, long long& bestDist) {
        if (node == nullptr) return;
        const long long d = distance(target, node->point);
        if (d < bestDist) {
            bestDist = d;
            best = node;
        }
        const std::size_t cd = depth % target.size();
        const long long diff = static_cast<long long>(target[cd]) - node->point[cd];
        // Search the side that contains the target first ...
        const Node* nearSide = diff < 0 ? node->left.get() : node->right.get();
        const Node* farSide = diff < 0 ? node->right.get() : node->left.get();
        nearest(nearSide, target, depth + 1, best, bestDist);
        // ... then the other side only if the splitting plane is closer than the best
        // match so far (squared distance compared with squared distance).
        if (diff * diff < bestDist) nearest(farSide, target, depth + 1, best, bestDist);
    }
};

int main() {
    KDTree tree;
    tree.insert({3, 6});
    tree.insert({17, 15});
    tree.insert({13, 15});
    tree.insert({6, 12});
    tree.insert({9, 1});
    tree.search({10, 10});  // prints "Nearest point: 6 12 "
    return 0;
}
```

这个例子实现了一个简单的k-d tree,包括插入和搜索操作。在这个例子中,我们使用了一个结构体Node来表示k-d tree的节点,每个节点包含一个向量point,表示多维空间中的一个点。在插入操作中,我们根据当前节点的深度和向量的某一维的大小来决定将新的点插入到左子树还是右子树中。在搜索操作中,我们先沿目标点所在的一侧递归到叶子,并记录当前最近点;回溯时,只有当目标点到分割超平面距离的平方小于当前最近距离的平方时,才需要继续搜索另一侧子树,直到找到最近的点。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值