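Overview
Given the current position position: [x, y] and a sequence of path points path: [[x1,y1]; [x2,y2]; ...; [xn,yn]] of size n×2, return the index nearest_index of the path point closest to position, together with the shortest distance nearest_dis. There is also one optional parameter, index_range: [range_upper, range_lower], which gives the first and the last index between which nearest_index must lie. It serves two purposes: first, it narrows the search range and so reduces the amount of computation; second, it copes with self-crossing paths by preventing the nearest point from jumping abruptly, as shown in the figure below:
In the figure, the black dots are the historical trajectory of positions, the green dots next to them are the corresponding nearest path points, and the blue dot is the current position. In most cases (for example in trajectory tracking for autonomous driving) we want the nearest point of the blue dot to be the green dot next to it, but judging by distance alone returns the red dot on the other branch of the path. Restricting the nearest point with index_range avoids this. If index_range is not set, the path point with the shortest distance over the whole path is returned. Note that the MATLAB version uses 1-based, inclusive indices, while the Python and C++ versions use 0-based indices with an exclusive end index.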
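The effect of index_range on a self-crossing path can be reproduced with a small pure-Python sketch. The path, the position and the window (previous index minus 2 up to previous index plus 4) are made-up values for illustration, and the helper nearest_point is hypothetical: it mirrors the brute-force search described below, with a 0-based, end-exclusive index_range.

```python
import math


def nearest_point(position, path, index_range=None):
    """Brute-force nearest path point; index_range is a 0-based, half-open [start, end)."""
    start, end = (0, len(path)) if index_range is None else index_range
    start, end = max(0, start), min(end, len(path))  # clamp to valid indices
    best_index, best_d2 = -1, math.inf
    for i in range(start, end):
        dx = path[i][0] - position[0]
        dy = path[i][1] - position[1]
        d2 = dx * dx + dy * dy  # squared distance; sqrt is taken once at the end
        if d2 < best_d2:
            best_index, best_d2 = i, d2
    return best_index, math.sqrt(best_d2)


# Hypothetical self-crossing path: east along y=0, north, west along y=5,
# then south along x=5.5, crossing the first leg near (5.5, 0).
path = [(float(x), 0.0) for x in range(11)]          # indices 0..10
path += [(10.0, float(y)) for y in range(1, 6)]      # indices 11..15
path += [(x + 0.5, 5.0) for x in range(9, 4, -1)]    # indices 16..20
path += [(5.5, float(y)) for y in range(4, -6, -1)]  # indices 21..30

position = (5.4, 0.4)  # vehicle still driving along the first leg
prev_index = 5         # nearest index found in the previous cycle

# Global search jumps to the crossing leg (index 25, distance ~0.41).
print(nearest_point(position, path))
# A window around the previous index keeps it on the first leg (index 5, distance ~0.57).
print(nearest_point(position, path, (prev_index - 2, prev_index + 4)))
```

In a tracking loop the window would be recentered on each new nearest index, so the result can only move a few points per cycle.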
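GitHub repository
The complete project is in my personal GitHub repository:
https://github.com/PeiyangKai/find_nearestPoint.git
It will keep being updated; if you find it useful, a star would be much appreciated!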
Implementation
Brute-force traversal search
matlab
function [nearest_index, nearest_dis] = find_nearest_points_new(position, path, index_range)
% Inputs
%   position:    [x, y]
%   path:        [x1, y1; x2, y2; ...; xn, yn]
%   index_range: [range_upper, range_lower], 1-based and inclusive:
%                1 <= range_upper <= range_lower <= size(path, 1)
% Outputs
%   nearest_index: index of the nearest point (into the full path)
%   nearest_dis:   distance to that point
if nargin < 3 || isempty(index_range)
    index_range = [1, size(path, 1)]; % default: search the whole path
end
start_index = index_range(1);
end_index = index_range(2);
new_path = path(start_index:end_index, 1:2);
% Squared distances (always non-negative); sqrt is applied only to the minimum
distance = (new_path(:, 1) - position(1)).^2 + (new_path(:, 2) - position(2)).^2;
[nearest_dis, relative_index] = min(distance);
nearest_dis = sqrt(nearest_dis);
nearest_index = start_index + relative_index - 1; % convert back to an index into path
end
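All three implementations compare squared distances and take a single square root at the end. With (x, y) = position, (x_i, y_i) the i-th path point, s = start_index and e = end_index, this yields the same index as comparing true distances, because the square root is strictly increasing on non-negative numbers (which is also why no abs() is needed around the squared distances):

$$
d_i^2 = (x_i - x)^2 + (y_i - y)^2, \qquad
\text{nearest\_index} = \arg\min_{s \le i \le e} d_i^2, \qquad
\text{nearest\_dis} = \sqrt{\min_{s \le i \le e} d_i^2}
$$

In the Python and C++ versions the upper bound is exclusive, i.e. s ≤ i < e.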
python
import numpy as np


def find_nearest_points_new(position, path, index_range=None):
    # Inputs
    #   position:    [x, y], np.ndarray
    #   path:        [[x1, y1], [x2, y2], ..., [xn, yn]], np.ndarray of shape (n, 2)
    #   index_range: [start_index, end_index], 0-based, end exclusive (Python slicing)
    # Outputs
    #   nearest_index: index of the nearest point (into the full path)
    #   nearest_dis:   distance to that point
    position = np.asarray(position, dtype=float)
    if index_range is None:
        index_range = [0, path.shape[0]]
    start_index, end_index = index_range
    new_path = path[start_index:end_index, :2]
    distance = np.sum((new_path - position) ** 2, axis=1)  # squared distances
    relative_index = int(np.argmin(distance))
    nearest_index = start_index + relative_index
    nearest_dis = float(np.sqrt(distance[relative_index]))
    return nearest_index, nearest_dis


# Test program
if __name__ == '__main__':
    position = np.array([5.1, 5])
    path = np.array([[i, i] for i in range(10)])
    nearest_index, nearest_dis = find_nearest_points_new(position, path)
    print(nearest_index, nearest_dis)
c++
#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>
#include <vector>

template <typename T>
struct Point2D
{
    T x, y;
    int index;
    Point2D() : x(0), y(0), index(-1) {}
    Point2D(T x, T y) : x(x), y(y), index(-1) {}
    Point2D(T x, T y, int index) : x(x), y(y), index(index) {}
    inline T& operator[](int i) {
        return i == 0 ? x : y;
    }
};

std::pair<int, double> find_nearest_points_traverse(const Point2D<double>& position, const std::vector<Point2D<double>>& path, std::pair<int, int> index_range = { 0, -1 }) {
    // Inputs
    //   position:    [x, y]
    //   path:        [x1, y1; x2, y2; ...; xn, yn]
    //   index_range: [start_index, end_index], 0-based, end exclusive; end_index == -1 means path.size()
    // Outputs
    //   first: index of the nearest point, second: distance to it ({-1, infinity} if the range is empty)
    const int n = static_cast<int>(path.size());
    int start_index = std::max(index_range.first, 0);
    int end_index = (index_range.second == -1) ? n : std::min(index_range.second, n);
    double nearest_dis = std::numeric_limits<double>::infinity();
    int nearest_index = -1;
    // Scan the window in place instead of copying it into a new vector
    for (int i = start_index; i < end_index; i++) {
        double dx = path[i].x - position.x;
        double dy = path[i].y - position.y;
        double distance = dx * dx + dy * dy; // squared distance
        if (distance < nearest_dis) {
            nearest_dis = distance;
            nearest_index = i;
        }
    }
    return std::make_pair(nearest_index, std::sqrt(nearest_dis));
}
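Efficient search with a KD-tree
KD-tree search is an efficient way to find nearest points: once the tree has been built, a query usually only needs to visit a small part of the tree rather than every path point.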
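Usage notes
A KD-tree must be built before it can be searched, and building it takes much longer than a single nearest-point query. KD-tree search is therefore not well suited to the index_range parameter: a different index_range means a different subset of the data, which would need its own KD-tree and greatly increase the run time.
If index_range is needed, brute-force traversal search is the better choice: index_range already confines the traversal to a limited window and removes most of the computation, whereas a KD-tree would have to be rebuilt every time.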
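A rough cost estimate supports this. Let n be the number of path points and m the number of points inside index_range. The build() function of the C++ KDTree below copies and sorts the data at every recursion level, so its cost follows the first recurrence; a query on a reasonably balanced tree over low-dimensional data typically visits O(log n) nodes, while traversal touches each of the m window points exactly once:

$$
T_{\text{build}}(n) = 2\,T_{\text{build}}\!\left(\tfrac{n}{2}\right) + O(n \log n) = O\!\left(n \log^2 n\right)
$$

$$
T_{\text{traverse}}(m) = O(m)
\qquad\text{vs.}\qquad
T_{\text{rebuild+query}}(m) = O\!\left(m \log^2 m\right) + O(\log m)
$$

Rebuilding a tree for every new window therefore always costs more than simply scanning the window.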
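Still, so that KD-tree search can also honor index_range, the following strategy is implemented: first search the 10 nearest points and check whether any of them lies inside index_range; if so, return the nearest such point. If not, search the 100 nearest points and check again; if there is still none, search the 1000 nearest points, and so on until the nearest point inside index_range is found or the whole path has been searched. (Note: in some cases this method is very time-consuming, far slower than traversal search, for example when index_range lies far away from the current position.)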
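The expanding search can be sketched in Python as follows. This is a minimal sketch under stated assumptions: the hypothetical k_nearest() stands in for a KD-tree k-nearest query (here it is a brute-force ranking with heapq.nsmallest so the example is self-contained), and index_range is 0-based with an exclusive end, as in the Python and C++ code.

```python
import heapq
import math


def k_nearest(position, path, k):
    """Hypothetical stand-in for a KD-tree k-NN query.

    Returns the k nearest (index, distance) pairs sorted by distance,
    computed by brute force so that the example is self-contained.
    """
    pairs = ((i, math.dist(position, p)) for i, p in enumerate(path))
    return heapq.nsmallest(k, pairs, key=lambda item: item[1])


def nearest_in_range(position, path, index_range, k0=10, growth=10):
    """Expanding-k search for the nearest point whose index lies in [start, end)."""
    start, end = index_range
    n = len(path)
    k = min(k0, n)
    while True:
        for index, dist in k_nearest(position, path, k):
            if start <= index < end:
                # The list is sorted by distance, so the first hit is the nearest in-range point.
                return index, dist
        if k >= n:
            return -1, math.inf  # every point was checked and none is in range
        k = min(k * growth, n)


path = [(float(i), 0.0) for i in range(100)]
print(nearest_in_range((0.0, 1.0), path, (50, 60)))  # index 50, found after k grows from 10 to 100
```

The first in-range hit is guaranteed to be the nearest one because every closer point is already in the k-nearest list. When the window is far from the query position, however, k keeps growing until every point has been ranked, which is why the strategy can end up slower than scanning the window directly.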
c++
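In the C++ version we implement our own KDTree class, covering both building the tree and k-nearest-neighbor search; the k-nearest-neighbor search uses a BoundedPQueue.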
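As a rough Python illustration of the same idea (a sketch, not the code used in this project), a bounded priority queue can be built on the standard heapq module by storing negated priorities, so that the element with the largest priority, the one to evict, sits at the top of the heap. The class and method names below are chosen to echo the C++ header and are otherwise hypothetical.

```python
import heapq
import itertools
import math


class BoundedPQueue:
    """Keeps at most max_size (value, priority) pairs with the smallest priorities."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._heap = []                    # (-priority, counter, value): a max-heap on priority
        self._counter = itertools.count()  # tie-breaker so values are never compared

    def __len__(self):
        return len(self._heap)

    def worst(self):
        """Largest stored priority, or infinity when empty (like worst() in C++)."""
        return -self._heap[0][0] if self._heap else math.inf

    def enqueue(self, value, priority):
        entry = (-priority, next(self._counter), value)
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, entry)
        elif self._heap and priority < self.worst():
            heapq.heapreplace(self._heap, entry)  # evict the current worst element

    def sorted_items(self):
        """All (value, priority) pairs, smallest priority first."""
        return sorted(((value, -neg) for neg, _, value in self._heap), key=lambda item: item[1])


bpq = BoundedPQueue(3)
for value, priority in [("a", 5.0), ("b", 1.0), ("c", 4.0), ("d", 2.0), ("e", 9.0)]:
    bpq.enqueue(value, priority)
print(bpq.sorted_items())  # [('b', 1.0), ('d', 2.0), ('c', 4.0)]
```

The C++ header layers the queue on a std::multimap, which keeps every element ordered; a heap only exposes the worst element cheaply, so the items are sorted once at the end, which is all a k-nearest search needs.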
- BoundedPQueue.h (implementation by Keith Schwarz)
/**
* Author: Keith Schwarz (htiek@cs.stanford.edu)
*
* An implementation of the bounded priority queue abstraction.
* A bounded priority queue is in many ways like a regular priority
* queue. It stores a collection of elements tagged with a real-
* valued priority, and allows for access to the element whose
* priority is the smallest. However, unlike a regular priority
* queue, the number of elements in a bounded priority queue has
* a hard limit that is specified in the constructor. Whenever an
* element is added to the bounded priority queue such that the
* size exceeds the maximum, the element with the highest priority
* value will be ejected from the bounded priority queue. In this
* sense, a bounded priority queue is like a high score table for
* a video game that stores a fixed number of elements and deletes
* the least-important entry whenever a new value is inserted.
*
* When creating a bounded priority queue, you must specify the
* maximum number of elements to store in the queue as an argument
* to the constructor. For example:
*
* BoundedPQueue<int> bpq(15); // Holds up to fifteen values.
*
* The maximum size of the bounded priority queue can be obtained
* using the maxSize() function, as in
*
* size_t k = bpq.maxSize();
*
* Beyond these restrictions, the bounded priority queue behaves
* similarly to other containers. You can query its size using
* size() and check whether it is empty using empty(). You
* can enqueue an element into the bounded priority queue by
* writing
*
* bpq.enqueue(elem, priority);
*
* Note that after enqueuing the element, there is no guarantee
* that the value will actually be in the queue. If the queue
* is full and the new element's priority exceeds the largest
* priority in the container, it will not be added.
*
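* You can dequeue elements from a bounded priority queue using
* the dequeueMin() function, which in this version returns the
* element together with its priority, as in
*
* std::pair<int, double> entry = bpq.dequeueMin(); // (value, priority)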
*
* The bounded priority queue also allows you to query the min
* and max priorities of the values in the queue. These values
* can be queried using the best() and worst() functions, which
* return the smallest and largest priorities in the queue,
* respectively.
*/
#ifndef BOUNDED_PQUEUE_INCLUDED
#define BOUNDED_PQUEUE_INCLUDED
#include <map>
#include <algorithm>
#include <limits>
#include <utility>
template <typename T>
class BoundedPQueue {
public:
// Constructor: BoundedPQueue(size_t maxSize);
// Usage: BoundedPQueue<int> bpq(15);
// --------------------------------------------------
// Constructs a new, empty BoundedPQueue with
// maximum size equal to the constructor argument.
///
explicit BoundedPQueue(std::size_t maxSize);
// void enqueue(const T& value, double priority);
// Usage: bpq.enqueue("Hi!", 2.71828);
// --------------------------------------------------
// Enqueues a new element into the BoundedPQueue with
// the specified priority. If this overflows the maximum
// size of the queue, the element with the highest
// priority will be deleted from the queue. Note that
// this might be the element that was just added.
void enqueue(const T& value, double priority);
// std::pair<T, double> dequeueMin();
// Usage: std::pair<int, double> entry = bpq.dequeueMin();
// --------------------------------------------------
// Returns the element from the BoundedPQueue with the
// smallest priority value, together with that priority,
// then removes that element from the queue. The queue
// must not be empty.
std::pair<T, double> dequeueMin();
// size_t size() const;
// bool empty() const;
// Usage: while (!bpq.empty()) { ... }
// --------------------------------------------------
// Returns the number of elements in the queue and whether
// the queue is empty, respectively.
std::size_t size() const;
bool empty() const;
// size_t maxSize() const;
// Usage: size_t queueSize = bpq.maxSize();
// --------------------------------------------------
// Returns the maximum number of elements that can be
// stored in the queue.
std::size_t maxSize() const;
// double best() const;
// double worst() const;
// Usage: double highestPriority = bpq.worst();
// --------------------------------------------------
// best() returns the smallest priority of an element
// stored in the container (i.e. the priority of the
// element that will be dequeued first using dequeueMin).
// worst() returns the largest priority of an element
// stored in the container. If an element is enqueued
// with a priority above this value, it will automatically
// be deleted from the queue. Both functions return
// numeric_limits<double>::infinity() if the queue is
// empty.
double best() const;
double worst() const;
private:
// This class is layered on top of a multimap mapping from priorities
// to elements with those priorities.
std::multimap<double, T> elems;
std::size_t maximumSize;
};
/** BoundedPQueue class implementation details */
template <typename T>
BoundedPQueue<T>::BoundedPQueue(std::size_t maxSize) {
maximumSize = maxSize;
}
// enqueue adds the element to the map, then deletes the last element of the
// map if its size exceeds the maximum size.
template <typename T>
void BoundedPQueue<T>::enqueue(const T& value, double priority) {
// Add the element to the collection.
elems.insert(std::make_pair(priority, value));
// If there are too many elements in the queue, drop off the last one.
if (size() > maxSize()) {
typename std::multimap<double, T>::iterator last = elems.end();
--last; // Now points to highest-priority element
elems.erase(last);
}
}
// dequeueMin copies the lowest element of the map (the one pointed at by
// begin()) and then removes it.
template <typename T>
std::pair<T, double> BoundedPQueue<T>::dequeueMin() {
// Copy the best value.
T value = elems.begin()->second;
double priority= elems.begin()->first;
// Remove it from the map.
elems.erase(elems.begin());
std::pair<T, double> result(value, priority);
return result;
}
// size() and empty() call directly down to the underlying map.
template <typename T>
std::size_t BoundedPQueue<T>::size() const {
return elems.size();
}
template <typename T>
bool BoundedPQueue<T>::empty() const {
return elems.empty();
}
// maxSize just returns the appropriate data member.
template <typename T>
std::size_t BoundedPQueue<T>::maxSize() const {
return maximumSize;
}
// The best() and worst() functions check if the queue is empty,
// and if so return infinity.
template <typename T>
double BoundedPQueue<T>::best() const {
return empty() ? std::numeric_limits<double>::infinity() : elems.begin()->first;
}
template <typename T>
double BoundedPQueue<T>::worst() const {
return empty() ? std::numeric_limits<double>::infinity() : elems.rbegin()->first;
}
#endif // BOUNDED_PQUEUE_INCLUDED
- KDtree.h
#pragma once
#include <cmath>
#include <iostream>
#include <algorithm>
#include <vector>
#include <set>
#include <stack>
#include <utility>
#include "BoundedPQueue.h"
// https://blog.csdn.net/weixin_42694889/article/details/124753575?ops_request_misc=&request_id=&biz_id=102&utm_term=KDtree%20c++&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-3-124753575.142^v99^pc_search_result_base9&spm=1018.2226.3001.4187
template <typename T>
struct Point3D
{
T x, y, z;
int index;
Point3D() : x(0), y(0), z(0), index(-1) {};
Point3D(T x, T y, T z) : x(x), y(y), z(z), index(-1) {};
Point3D(T x, T y, T z, int index) : x(x), y(y), z(z), index(index) {};
inline T& operator[](int i) {
return i == 0 ? x : i == 1 ? y : z;
}
};
template <typename T>
struct Point2D
{
T x, y;
int index;
Point2D() : x(0), y(0), index(-1) {};
Point2D(T x, T y) : x(x), y(y), index(-1) {};
Point2D(T x, T y, int index) : x(x), y(y), index(index) {};
inline T& operator[](int i) {
return i == 0 ? x : y;
}
};
struct KDNode
{
int index; // index of the data point stored in this node
int axis;  // splitting axis (-1 for a leaf)
KDNode* left;
KDNode* right;
KDNode(int index, int axis, KDNode* left = nullptr, KDNode* right = nullptr) {
this->index = index;
this->axis = axis;
this->left = left;
this->right = right;
}
};
template <typename T>
class KDTree
{
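private:
    int ndim;                        // number of dimensions
    KDNode* root;
    KDNode* build(std::vector<T>&);  // build the KD-tree recursively
    std::set<int> visited;           // nodes already visited during backtracking
    std::stack<KDNode*> stackNode;   // current root-to-node search path
    std::vector<T> m_data;           // copy of the input; m_data[i].index must equal i
    void release(KDNode*);
    void printNode(KDNode*);
    int chooseAxis(std::vector<T>&); // choose the splitting axis
    void dfs(KDNode*, T);            // depth-first descent towards the query point
    // Euclidean distance between a stored point and the query point
    inline double distanceT(KDNode*, T);
    inline double distanceT(int, T);
    // Distance from the query point to the node's splitting hyperplane
    inline double distanceP(KDNode*, T);
    // Whether the parent's splitting hyperplane intersects the search hypersphere
    inline bool checkParent(KDNode*, T, double);
public:
    int NodeNums;
    KDTree(std::vector<T>&, int dim);
    ~KDTree();
    // The tree owns raw node pointers, so copying is disabled to avoid a double delete
    KDTree(const KDTree&) = delete;
    KDTree& operator=(const KDTree&) = delete;
    void Print();
    // Returns up to k (index, distance) pairs sorted by increasing distance
    std::vector<std::pair<int, double>> findNearestPoint(T point, int k = 1);
};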
// Constructor
template <typename T> // T can be Point3D or Point2D
KDTree<T>::KDTree(std::vector<T>& data, int dim) {
    this->ndim = dim;
    NodeNums = data.size();
    this->m_data = data; // copy the original data
    root = build(data);
}
// Destructor
template <typename T>
KDTree<T> ::~KDTree() {
release(root);
}
template <typename T>
void KDTree<T>::Print() {
printNode(root);
}
template <typename T>
void KDTree<T>:: release(KDNode* node) {
if (node == nullptr) {
return;
}
release(node->left);
release(node->right);
delete node;
node = nullptr;
}
template <typename T>
KDNode* KDTree<T>::build(std::vector<T>& data) {
    if (data.empty()) return nullptr;
    std::vector<T> temp = data;
    int mid_index = static_cast<int>(data.size() / 2); // middle position of the data
    // Choose the splitting axis by per-axis variance; a single point (leaf) cannot be split, so its axis is -1
    int axis = data.size() > 1 ? chooseAxis(temp) : -1;
    // Sort along the chosen axis with a lambda comparator; the median point becomes this node
    std::sort(temp.begin(), temp.end(), [axis](T a, T b) { return a[axis] < b[axis]; });
    std::vector<T> leftData(temp.begin(), temp.begin() + mid_index);
    std::vector<T> rightData(temp.begin() + mid_index + 1, temp.end());
    KDNode* leftNode = build(leftData);
    KDNode* rightNode = build(rightData);
    return new KDNode(temp[mid_index].index, axis, leftNode, rightNode);
}
template <typename T>
void KDTree<T>::printNode(KDNode* node) {
if (node) {
std::cout << "Index: " << node->index << "\tAxis: " << node->axis << std::endl;
printNode(node->left);
printNode(node->right);
}
}
template <typename T>
int KDTree<T>::chooseAxis(std::vector<T>& data) {
int axis = -1;
double max_var = -1;
for (int i = 0; i < ndim; i++) {
// mean along axis i
double mean = 0;
for (auto j : data) {
mean += static_cast<double>(j[i]);
}
mean = mean / static_cast<double>(data.size());
// variance along axis i
double var = 0;
for (auto j : data) {
var += (static_cast<double>(j[i]) - mean) *(static_cast<double>(j[i]) - mean);
}
var = var / static_cast<double>(data.size());
if (var > max_var) {
max_var = var;
axis = i;
}
}
return axis;
}
template <typename T>
inline double KDTree<T>::distanceT(KDNode* node, T point)
{
double dis = 0;
for (int i = 0; i < ndim; i++) {
dis += (m_data[node->index][i] - point[i]) * (m_data[node->index][i] - point[i]);
}
dis = sqrt(dis);
return dis;
}
template <typename T>
inline double KDTree<T>::distanceT(int index, T point) {
double dis = 0;
for (int i = 0; i < ndim; i++) {
dis += (m_data[index][i] - point[i]) * (m_data[index][i] - point[i]);
}
dis = sqrt(dis);
return dis;
}
template <class T>
double KDTree<T>::distanceP(KDNode* node, T point) {
int axis = node->axis;
double dis = m_data[node->index][axis] - point[axis];
return std::abs(dis); // std::abs(double); an unqualified abs() may resolve to the int overload
}
template <class T>
bool KDTree<T>::checkParent(KDNode* node, T pt, double distT)
{
double dis = distanceP(node, pt);
return dis <= distT;
}
template <typename T>
std::vector<std::pair<int, double>> KDTree<T>::findNearestPoint(T point, int k) {
    visited.clear();
    while (!stackNode.empty()) {
        stackNode.pop(); // clear the search stack
    }
    BoundedPQueue<int> pQueue(k); // BPQ with maximum size k
    dfs(root, point); // descend to the region containing the query point
    while (!stackNode.empty()) {
        KDNode* curNode = stackNode.top();
        stackNode.pop();
        double dist = distanceT(curNode, point);
        pQueue.enqueue(curNode->index, dist);
        if (!stackNode.empty()) {
            // The stack always holds a root-to-node path, so the top is curNode's parent
            KDNode* parentNode = stackNode.top();
            // Explore the parent's other subtree if the queue is not full yet, or if the
            // hypersphere of radius worst() intersects the parent's splitting hyperplane
            if (pQueue.size() < pQueue.maxSize() || checkParent(parentNode, point, pQueue.worst())) {
                // Compare pointers rather than coordinates, so that points lying exactly
                // on the splitting plane cannot send the search to the wrong side
                if (curNode == parentNode->left) {
                    dfs(parentNode->right, point);
                }
                else {
                    dfs(parentNode->left, point);
                }
            }
        }
    }
    std::vector<std::pair<int, double>> NearestKPointIndex;
    while (!pQueue.empty()) {
        NearestKPointIndex.push_back(pQueue.dequeueMin());
    }
    return NearestKPointIndex;
}
template <class T>
void KDTree<T>::dfs(KDNode* node, T pt)
{
if (node) {
if (visited.find(node->index) != visited.end())
return; // already visited
stackNode.push(node);
visited.insert(node->index);
if (pt[node->axis] <= m_data[node->index][node->axis] && node->left)
dfs(node->left,pt);
else if(pt[node->axis] > m_data[node->index][node->axis] && node->right)
{
dfs(node->right,pt);
}
else if ((node->left == nullptr) ^ (node->right == nullptr))
{
dfs(node->left, pt);
dfs(node->right, pt);
}
}
}
- main.cpp
#include <algorithm>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "KDTree.h"
using namespace std;
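// Reads a CSV file whose first two columns are x and y; the first line is a header
void read_path(std::vector<Point2D<double>>& data, const string& fileName) {
    std::ifstream file(fileName);
    if (!file.is_open()) {
        std::cerr << "Cannot open file: " << fileName << std::endl;
        return;
    }
    std::string line;
    Point2D<double> pt = {0, 0, 0};
    int id = 0;
    // Read the first line (column names) and discard it
    std::getline(file, line);
    // Read the file line by line
    while (std::getline(file, line)) {
        std::stringstream ss(line);
        std::string cell;
        std::vector<std::string> tokens;
        // Split the line into comma-separated cells
        while (std::getline(ss, cell, ',')) {
            tokens.push_back(cell);
        }
        // Make sure there is enough data (at least x and y)
        if (tokens.size() >= 2) {
            // Convert x and y to floating point; index = position in data (required by KDTree)
            pt.x = std::stod(tokens[0]);
            pt.y = std::stod(tokens[1]);
            pt.index = id++;
            data.push_back(pt);
        }
    }
}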
std::pair<int, double> find_nearest_points_traverse(const Point2D<double>& position, const std::vector<Point2D<double>>& path, std::pair<int, int> index_range = { 0, -1 }) {
    // Inputs
    //   position:    [x, y]
    //   path:        [x1, y1; x2, y2; ...; xn, yn]
    //   index_range: [start_index, end_index], 0-based, end exclusive; end_index == -1 means path.size()
    // Outputs
    //   first: index of the nearest point, second: distance to it ({-1, infinity} if the range is empty)
    const int n = static_cast<int>(path.size());
    int start_index = std::max(index_range.first, 0);
    int end_index = (index_range.second == -1) ? n : std::min(index_range.second, n);
    double nearest_dis = std::numeric_limits<double>::infinity();
    int nearest_index = -1;
    // Scan the window in place instead of copying it into a new vector
    for (int i = start_index; i < end_index; i++) {
        double dx = path[i].x - position.x;
        double dy = path[i].y - position.y;
        double distance = dx * dx + dy * dy; // squared distance
        if (distance < nearest_dis) {
            nearest_dis = distance;
            nearest_index = i;
        }
    }
    return std::make_pair(nearest_index, std::sqrt(nearest_dis));
}
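std::pair<int, double> find_nearest_points_KDTree(KDTree<Point2D<double>>& kdt, Point2D<double> position, std::pair<int, int> index_range = { 0, -1 })
{
    // Inputs
    //   kdt:         KD-tree built over the whole path
    //   position:    [x, y]
    //   index_range: [start_index, end_index], 0-based, end exclusive; end_index == -1 means the whole path
    // Outputs
    //   first: index of the nearest point inside index_range, second: distance to it
    //   ({-1, infinity} if no point lies inside index_range)
    if (index_range.second == -1) {
        index_range.second = kdt.NodeNums;
    }
    // Search the 10, 100, 1000, ... nearest points until one falls inside index_range
    int searchKPoint = 10;
    bool searchFlag = true;
    while (searchFlag) {
        if (searchKPoint >= kdt.NodeNums) {
            searchKPoint = kdt.NodeNums; // last round: rank every point
            searchFlag = false;
        }
        auto nearestKPoint = kdt.findNearestPoint(position, searchKPoint);
        // Results are sorted by distance, so the first hit is the nearest point in range
        for (const auto& pt : nearestKPoint) {
            if (pt.first >= index_range.first && pt.first < index_range.second) {
                return pt;
            }
        }
        searchKPoint = searchKPoint * 10;
    }
    return std::make_pair(-1, std::numeric_limits<double>::infinity());
}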
int main() {
    // Read the data
    std::vector<Point2D<double>> path;
    read_path(path, "E:\\Desktop\\blog\\road.csv");
    // Query target
    Point2D<double> position = { -550.9, 37.5 };
    std::pair<int, int> index_range = { 0, 6500 };
    std::pair<int, double> nearest_point;
    // steady_clock is monotonic, which makes it the right clock for measuring intervals
    using Clock = std::chrono::steady_clock;
    auto to_ms = [](Clock::duration d) {
        return std::chrono::duration_cast<std::chrono::milliseconds>(d).count();
    };
    // Brute-force traversal search, 1000 queries
    auto t1 = Clock::now();
    for (int i = 0; i < 1000; i++) {
        nearest_point = find_nearest_points_traverse(position, path, index_range);
    }
    auto t2 = Clock::now();
    cout << "Nearest index: " << nearest_point.first << "  Nearest distance: " << nearest_point.second << endl;
    cout << "Traversal search time: " << to_ms(t2 - t1) << " ms" << endl;
    // KD-tree search: build the tree once
    auto t3 = Clock::now();
    int dim = 2;
    KDTree<Point2D<double>> kdt(path, dim);
    auto t4 = Clock::now();
    cout << "KD-tree build time: " << to_ms(t4 - t3) << " ms" << endl;
    // then run 1000 queries
    auto t5 = Clock::now();
    for (int i = 0; i < 1000; i++) {
        nearest_point = find_nearest_points_KDTree(kdt, position, index_range);
    }
    auto t6 = Clock::now();
    cout << "Nearest index: " << nearest_point.first << "  Nearest distance: " << nearest_point.second << endl;
    cout << "KD-tree search time: " << to_ms(t6 - t5) << " ms" << endl;
    return 0;
}
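main.cpp compares the run time of brute-force traversal search and KD-tree search:
- without the index_range parameter: KD-tree search is faster than traversal search;
- with the index_range parameter: traversal search is faster than KD-tree search.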
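To experiment with the algorithm without compiling the C++ code, the KDTree above can be sketched in pure Python. This is a minimal sketch, not the project's code: it keeps the variance-based axis choice, the median split and the hyperplane-distance pruning, but replaces the explicit stack and visited set with recursion and a size-k heap. The names build, knn and choose_axis are hypothetical.

```python
import heapq
import math


class KDNode:
    def __init__(self, index, axis, left=None, right=None):
        self.index = index  # index of the stored point
        self.axis = axis    # splitting axis
        self.left = left
        self.right = right


def choose_axis(points, idxs, ndim):
    """Axis with the largest variance, like chooseAxis() in KDtree.h."""
    best_axis, best_var = 0, -1.0
    for axis in range(ndim):
        vals = [points[i][axis] for i in idxs]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)
        if var > best_var:
            best_axis, best_var = axis, var
    return best_axis


def build(points, idxs, ndim=2):
    """Median split along the highest-variance axis; returns the subtree root."""
    if not idxs:
        return None
    axis = choose_axis(points, idxs, ndim)
    idxs = sorted(idxs, key=lambda i: points[i][axis])
    mid = len(idxs) // 2
    return KDNode(idxs[mid], axis,
                  build(points, idxs[:mid], ndim),
                  build(points, idxs[mid + 1:], ndim))


def knn(root, points, query, k=1):
    """Up to k (index, distance) pairs sorted by distance."""
    if k <= 0:
        return []
    heap = []  # max-heap of (-distance, index) holding the best k so far

    def visit(node):
        if node is None:
            return
        dist = math.dist(points[node.index], query)
        if len(heap) < k:
            heapq.heappush(heap, (-dist, node.index))
        elif dist < -heap[0][0]:
            heapq.heapreplace(heap, (-dist, node.index))
        diff = query[node.axis] - points[node.index][node.axis]
        near, far = (node.left, node.right) if diff <= 0 else (node.right, node.left)
        visit(near)
        # Cross the splitting plane only if the k-th best ball reaches it.
        if len(heap) < k or abs(diff) <= -heap[0][0]:
            visit(far)

    visit(root)
    return sorted(((i, -neg) for neg, i in heap), key=lambda item: item[1])


path = [(float(i), 0.0) for i in range(20)]
root = build(path, list(range(len(path))))
print([i for i, _ in knn(root, path, (7.2, 0.5), k=2)])  # [7, 8]
```

The pruning test is safe because every point in the far subtree differs from the query by at least |diff| along the splitting axis.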
python
The Python version uses the KDTree class from the scipy library (scipy.spatial.KDTree).
import time

import numpy as np
import pandas as pd
from scipy.spatial import KDTree


# KD-tree search
def find_nearest_points_KDTree(kdtree, point, index_range=None):
    # Returns (nearest_dist, nearest_idx); index_range is [start_index, end_index), 0-based.
    # Returns (np.inf, -1) if no point lies inside index_range.
    if index_range is None:
        nearest_dist, nearest_idx = kdtree.query(point)
        return nearest_dist, nearest_idx
    n = len(kdtree.data)
    k_neighbors = 10
    while True:
        k_neighbors = min(k_neighbors, n)
        nearest_dist, nearest_idx = kdtree.query(point, k_neighbors)
        # query() returns scalars when k == 1, so make both results 1-D arrays
        nearest_dist = np.atleast_1d(nearest_dist)
        nearest_idx = np.atleast_1d(nearest_idx)
        # Results are sorted by distance, so the first hit is the nearest point in range
        for dist, idx in zip(nearest_dist, nearest_idx):
            if index_range[0] <= idx < index_range[1]:
                return dist, idx
        if k_neighbors >= n:
            return np.inf, -1  # all points searched, none inside index_range
        k_neighbors *= 10


def find_nearest_points_traverse(path, position, index_range=None):
    # Inputs
    #   path:        [[x1, y1], ..., [xn, yn]], np.ndarray of shape (n, 2)
    #   position:    [x, y], np.ndarray or list
    #   index_range: [start_index, end_index], 0-based, end exclusive
    # Outputs (note the order): nearest_dis, nearest_index
    if index_range is None:
        index_range = [0, path.shape[0]]
    start_index, end_index = index_range
    new_path = path[start_index:end_index, :2]
    distance = np.sum((new_path - np.asarray(position)) ** 2, axis=1)  # squared distances
    relative_index = int(np.argmin(distance))
    nearest_index = start_index + relative_index
    nearest_dis = float(np.sqrt(distance[relative_index]))
    return nearest_dis, nearest_index


if __name__ == '__main__':
    # Read the data
    df = pd.read_csv('road.csv', header=0)
    path = df[['x', 'y']].values
    # Build the KD-tree
    t1 = time.perf_counter()
    kdtree = KDTree(path)
    t2 = time.perf_counter()
    print("KD-tree build time:", 1000 * (t2 - t1), "ms")
    # Query point
    point = [-550.9, 37.5]
    t3 = time.perf_counter()
    for i in range(1000):
        nearest_dist, nearest_idx = find_nearest_points_KDTree(kdtree, point, index_range=[0, 6500])
        # nearest_dist, nearest_idx = find_nearest_points_KDTree(kdtree, point)
    t4 = time.perf_counter()
    print("KD-tree search time:", 1000 * (t4 - t3), "ms")
    print("Nearest point index:", nearest_idx, "Distance:", nearest_dist)
    t5 = time.perf_counter()
    for i in range(1000):
        nearest_dist, nearest_idx = find_nearest_points_traverse(path, point, index_range=[0, 6500])
        # nearest_dist, nearest_idx = find_nearest_points_traverse(path, point)
    t6 = time.perf_counter()
    print("Traversal search time:", 1000 * (t6 - t5), "ms")
    print("Nearest point index:", nearest_idx, "Distance:", nearest_dist)
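The main block compares the run time of brute-force traversal search and KD-tree search:
- without the index_range parameter: KD-tree search is faster than traversal search;
- with the index_range parameter: traversal search is faster than KD-tree search.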