KD-Tree 开源实现代码
- 使用方法:添加头文件 #include <KDTree.hpp>
- 测试代码及源文件:下载链接(OpenCV 版本:4.5.1)
- 头文件名为 “KDTree.hpp”
#ifndef __KDTREE_H__
#define __KDTREE_H__
#include <algorithm>
#include <cmath>
#include <exception>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>
namespace obstarcalib::caliblink {

/// Static k-d tree answering nearest-neighbour (nnSearch), k-nearest-neighbour
/// (knnSearch) and radius (radiusSearch) queries with Euclidean distance.
///
/// @tparam PointT    Point type. Must support `p[i]` for 0 <= i < DIM; the
///                   coordinates must be comparable with `<` and convertible to
///                   double (e.g. `float*`, `std::array<double, 3>`).
/// @tparam DIM       Number of coordinates used for splitting and distances.
/// @tparam Allocator Allocator of the stored point vector.
///
/// The tree keeps its own copy of the point vector; when PointT is a pointer
/// only the pointers are copied, so the pointed-to data must outlive the tree.
/// Every index returned by a query refers to the vector passed to build().
template <typename PointT, int DIM = 2, typename Allocator = std::allocator<PointT>>
class KDTree {
    static_assert(DIM > 0, "KDTree requires at least one dimension");

public:
    /// Creates an empty tree; queries on it return no result.
    KDTree()
        : root_(nullptr)
    {
    }

    /// Builds a tree over `points` (see build()).
    KDTree(const std::vector<PointT, Allocator>& points)
        : root_(nullptr)
    {
        build(points);
    }

    /// Deep copy. The node structure is duplicated so the two trees never share
    /// nodes (a member-wise copy would make both destructors free the same nodes).
    KDTree(const KDTree& other)
        : root_(nullptr)
        , points_(other.points_)
    {
        root_ = cloneRecursive(other.root_);
    }

    /// Takes over `other`'s nodes and points; `other` is left empty.
    KDTree(KDTree&& other) noexcept
        : root_(other.root_)
        , points_(std::move(other.points_))
    {
        other.root_ = nullptr;
        other.points_.clear();
    }

    /// Copy assignment. The copy is made before *this is modified.
    KDTree& operator=(const KDTree& other)
    {
        if (this != &other)
            *this = KDTree(other);
        return *this;
    }

    /// Move assignment; releases the current tree and leaves `other` empty.
    KDTree& operator=(KDTree&& other)
    {
        if (this != &other) {
            clear();
            // Move the points before taking the nodes so that a throwing move
            // (only possible with exotic allocators) leaves both trees consistent.
            points_ = std::move(other.points_);
            root_ = other.root_;
            other.root_ = nullptr;
            other.points_.clear();
        }
        return *this;
    }

    ~KDTree() { clear(); }

    /// Replaces the tree contents with a median-split tree over a copy of `points`.
    void build(const std::vector<PointT, Allocator>& points)
    {
        clear();
        points_ = points;
        std::vector<int> indices(points_.size());
        std::iota(std::begin(indices), std::end(indices), 0);
        root_ = buildRecursive(indices.data(), static_cast<int>(indices.size()), 0);
    }

    /// Frees all nodes and forgets the stored points.
    void clear()
    {
        clearRecursive(root_);
        root_ = nullptr;
        points_.clear();
    }

    /// Checks the k-d tree invariants for every node.
    /// @return true if the tree is well formed (an empty tree is well formed).
    bool validate() const
    {
        try {
            validateRecursive(root_, 0);
        } catch (const Exception&) {
            return false;
        }
        return true;
    }

    /// Finds the stored point closest to `query`.
    /// @param query   point to search for.
    /// @param minDist optional out-parameter receiving the Euclidean distance to
    ///                the returned point (numeric_limits<double>::max() if the
    ///                tree is empty).
    /// @return index of the nearest point, or -1 if the tree is empty.
    int nnSearch(const PointT& query, double* minDist = nullptr) const
    {
        int guess = -1; // stays -1 when there is nothing to search
        double bestDist = std::numeric_limits<double>::max();
        nnSearchRecursive(query, root_, &guess, &bestDist);
        if (minDist)
            *minDist = bestDist;
        return guess;
    }

    /// Finds the `k` stored points closest to `query`.
    /// @return up to `k` indices sorted by ascending distance; empty if `k <= 0`
    ///         or the tree is empty. Fewer than `k` are returned if the tree
    ///         holds fewer points.
    std::vector<int> knnSearch(const PointT& query, int k) const
    {
        if (k <= 0)
            return {};
        // The queue can never hold more than points_.size() entries, so bound it
        // by that to avoid reserving a huge buffer for a huge k.
        KnnQueue queue(std::min<std::size_t>(static_cast<std::size_t>(k), points_.size()));
        knnSearchRecursive(query, root_, queue, k);
        std::vector<int> indices(queue.size());
        for (std::size_t i = 0; i < queue.size(); i++)
            indices[i] = queue[i].second;
        return indices;
    }

    /// Finds all stored points whose distance to `query` is strictly less than
    /// `radius`.
    /// @return matching indices in tree-traversal order (not sorted by distance).
    std::vector<int> radiusSearch(const PointT& query, double radius) const
    {
        std::vector<int> indices;
        radiusSearchRecursive(query, root_, indices, radius);
        return indices;
    }

private:
    // One tree node: `idx` indexes points_, `axis` is the splitting coordinate,
    // next[0] / next[1] are the lower / upper subtrees (owned by this node).
    struct Node {
        int idx;
        Node* next[2];
        int axis;
        Node()
            : idx(-1)
            , axis(-1)
        {
            next[0] = next[1] = nullptr;
        }
    };

    // Thrown by validateRecursive() and caught by validate().
    class Exception : public std::exception {
        using std::exception::exception;
    };

    // Keeps the `bound` smallest elements pushed so far, sorted ascending by
    // Compare; equal elements keep their insertion order.
    template <class T, class Compare = std::less<T>>
    class BoundedPriorityQueue {
    public:
        BoundedPriorityQueue() = delete;
        explicit BoundedPriorityQueue(std::size_t bound)
            : bound_(bound)
        {
            elements_.reserve(bound + 1);
        }
        void push(const T& val)
        {
            // Insert in front of the first element that `val` is strictly less than.
            auto it = std::find_if(std::begin(elements_), std::end(elements_),
                [&](const T& element) { return Compare()(val, element); });
            elements_.insert(it, val);
            if (elements_.size() > bound_)
                elements_.resize(bound_);
        }
        const T& back() const { return elements_.back(); }
        const T& operator[](std::size_t index) const { return elements_[index]; }
        std::size_t size() const { return elements_.size(); }

    private:
        std::size_t bound_;
        std::vector<T> elements_;
    };

    // (distance, point index) pairs, nearest first.
    using KnnQueue = BoundedPriorityQueue<std::pair<double, int>>;

    // Builds the subtree over indices[0, npoints), splitting on axis depth % DIM
    // at the median. Frees any partially built subtree if an allocation throws.
    Node* buildRecursive(int* indices, int npoints, int depth)
    {
        if (npoints <= 0)
            return nullptr;
        const int axis = depth % DIM;
        const int mid = (npoints - 1) / 2;
        // Afterwards no element before `mid` is greater, and no element after it
        // is smaller, on `axis` than the median — the invariant validate() checks.
        std::nth_element(indices, indices + mid, indices + npoints, [&](int lhs, int rhs) {
            return points_[lhs][axis] < points_[rhs][axis];
        });
        Node* node = new Node();
        node->idx = indices[mid];
        node->axis = axis;
        try {
            node->next[0] = buildRecursive(indices, mid, depth + 1);
            node->next[1] = buildRecursive(indices + mid + 1, npoints - mid - 1, depth + 1);
        } catch (...) {
            clearRecursive(node);
            throw;
        }
        return node;
    }

    // Deletes `node` and its whole subtree; a null node is a no-op.
    static void clearRecursive(Node* node)
    {
        if (node == nullptr)
            return;
        clearRecursive(node->next[0]);
        clearRecursive(node->next[1]);
        delete node;
    }

    // Deep-copies the subtree rooted at `src`, freeing any partial copy before
    // rethrowing if an allocation fails.
    static Node* cloneRecursive(const Node* src)
    {
        if (src == nullptr)
            return nullptr;
        Node* dst = new Node();
        dst->idx = src->idx;
        dst->axis = src->axis;
        try {
            dst->next[0] = cloneRecursive(src->next[0]);
            dst->next[1] = cloneRecursive(src->next[1]);
        } catch (...) {
            clearRecursive(dst);
            throw;
        }
        return dst;
    }

    // Throws Exception if the subtree rooted at `node` (at tree depth `depth`)
    // breaks a k-d tree invariant: the split axis must be depth % DIM, no point in
    // the lower subtree may exceed the node on that axis, and no point in the
    // upper subtree may be below it.
    void validateRecursive(const Node* node, int depth) const
    {
        if (node == nullptr)
            return;
        const int axis = node->axis;
        if (axis != depth % DIM)
            throw Exception();
        const PointT& pivot = points_[node->idx];
        const bool lowerOk = allInSubtree(node->next[0],
            [&](const PointT& p) { return !(pivot[axis] < p[axis]); });
        const bool upperOk = allInSubtree(node->next[1],
            [&](const PointT& p) { return !(p[axis] < pivot[axis]); });
        if (!lowerOk || !upperOk)
            throw Exception();
        validateRecursive(node->next[0], depth + 1);
        validateRecursive(node->next[1], depth + 1);
    }

    // Returns true if `pred` holds for the point of every node in the subtree.
    template <class Pred>
    bool allInSubtree(const Node* node, const Pred& pred) const
    {
        if (node == nullptr)
            return true;
        return pred(points_[node->idx])
            && allInSubtree(node->next[0], pred)
            && allInSubtree(node->next[1], pred);
    }

    // Euclidean distance over the first DIM coordinates. Differences are taken
    // in double so integer coordinates cannot overflow.
    static double distance(const PointT& p, const PointT& q)
    {
        double sum = 0;
        for (int i = 0; i < DIM; i++) {
            const double d = static_cast<double>(p[i]) - static_cast<double>(q[i]);
            sum += d * d;
        }
        return std::sqrt(sum);
    }

    // Absolute distance from `query` to the splitting plane of `train` on `axis`.
    static double planeDistance(const PointT& query, const PointT& train, int axis)
    {
        return std::fabs(static_cast<double>(query[axis]) - static_cast<double>(train[axis]));
    }

    // Descends towards `query`, then visits the far side only if the splitting
    // plane is closer than the best distance found so far.
    void nnSearchRecursive(const PointT& query, const Node* node, int* guess, double* minDist) const
    {
        if (node == nullptr)
            return;
        const PointT& train = points_[node->idx];
        const double dist = distance(query, train);
        if (dist < *minDist) {
            *minDist = dist;
            *guess = node->idx;
        }
        const int axis = node->axis;
        const int dir = query[axis] < train[axis] ? 0 : 1;
        nnSearchRecursive(query, node->next[dir], guess, minDist);
        if (planeDistance(query, train, axis) < *minDist)
            nnSearchRecursive(query, node->next[!dir], guess, minDist);
    }

    // Like nnSearchRecursive, but the pruning radius is the k-th best distance
    // (no pruning until k candidates have been collected).
    void knnSearchRecursive(const PointT& query, const Node* node, KnnQueue& queue, int k) const
    {
        if (node == nullptr)
            return;
        const PointT& train = points_[node->idx];
        queue.push(std::make_pair(distance(query, train), node->idx));
        const int axis = node->axis;
        const int dir = query[axis] < train[axis] ? 0 : 1;
        knnSearchRecursive(query, node->next[dir], queue, k);
        if (static_cast<int>(queue.size()) < k || planeDistance(query, train, axis) < queue.back().first)
            knnSearchRecursive(query, node->next[!dir], queue, k);
    }

    // Collects every point closer than `radius`; the far side is visited only if
    // the splitting plane lies within `radius`.
    void radiusSearchRecursive(const PointT& query, const Node* node, std::vector<int>& indices, double radius) const
    {
        if (node == nullptr)
            return;
        const PointT& train = points_[node->idx];
        if (distance(query, train) < radius)
            indices.push_back(node->idx);
        const int axis = node->axis;
        const int dir = query[axis] < train[axis] ? 0 : 1;
        radiusSearchRecursive(query, node->next[dir], indices, radius);
        if (planeDistance(query, train, axis) < radius)
            radiusSearchRecursive(query, node->next[!dir], indices, radius);
    }

    Node* root_;                            // owned; nullptr when the tree is empty
    std::vector<PointT, Allocator> points_; // copy of the points passed to build()
};

} // namespace obstarcalib::caliblink
#endif
#include <iostream>
#include <KDTree.hpp>
#include <opencv2/opencv.hpp>
using namespace std;
using namespace obstarcalib::caliblink;
using namespace cv;
void testMyKDTreeNN()
{
float query[6][2];
for (int i = 0; i < 6; i++)
{
query[i][0] = i; query[i][1] = i + 1;
}
vector<float *> v1 = {query, query + 6};
KDTree<float *> kdTree(v1);
float train[2] = {3.6, 7.1};
double min = 0;
double *minDist = &min;
int index = kdTree.nnSearch(train, minDist);
cout << index << "---" << *minDist << endl;
}
void testMyKDTreeKNN()
{
float query[6][2];
for (int i = 0; i < 6; i++)
{
query[i][0] = i; query[i][1] = i + 1;
}
vector<float *> v1 = {query, query + 6};
KDTree<float *> kdTree(v1);
float train[2] = {3.6, 7.1};
int k = 2;
vector<int> index = kdTree.knnSearch(train, k);
cout << index[0] << endl;
}
void testOpenCVKDTreeKNN()
{
vector<Point2f> features;
for (int i = 0; i < 6; i++)
features.push_back({i, i + 1});
Mat source = Mat(features).reshape(1);
source.convertTo(source, CV_32F);
cv::flann::KDTreeIndexParams indexParams(1);
cv::flann::Index kdtree(source, indexParams);
int queryNum = 2;
vector<float> vecQuery(2);
vector<int> vecIndex(queryNum);
vector<float> vecDist(queryNum);
cv::flann::SearchParams params(32);
vecQuery = {3.6, 7.1};
kdtree.knnSearch(vecQuery, vecIndex, vecDist, queryNum, params);
cout << vecIndex[0] << "---" << sqrt(vecDist[0]) << endl;
}
// Runs the three comparison tests in order.
int main(int argc, char **argv)
{
    testMyKDTreeNN();
    testMyKDTreeKNN();
    testOpenCVKDTreeKNN();
#ifdef _WIN32
    // "pause" is a cmd.exe builtin (presumably used to keep the console window
    // open); on other platforms system("pause") only spawns a failing shell.
    system("pause");
#endif
    return 0;
}
总结
- 该开源 KD-Tree 实现要求点类型(PointT)支持 [] 下标访问;此外,knnSearch() 接口不返回匹配点的距离,如有需要可自行修改源码。
- OpenCV KD-Tree 的 KNN(K-Nearest Neighbors,K 近邻)匹配速度优于该开源 KD-Tree;但若只需获取最近邻(NN,Nearest Neighbor),使用该开源 KD-Tree 更快。