假设平面上有n个点,要求求解最近的两点之间的距离。
首先可以想到暴力法,遍历所有点求解其与其它点的最近距离,时间复杂度为O(n^2)。
但还有另外两种复杂度更低的方法,分治法和k-d树法。(道理我相信大多数人都懂,文章也很多,关键是代码实现,可以直接看代码)
分治法
分治法的详细介绍可以参考这两篇文章,原理都差不多。
计算几何 平面最近点对 nlogn分治算法 求平面中距离最近的两点
uva-10245-The Closest Pair Problem-分治算法
主要的思路就是,将点集分为左右两个部分,最小距离有三种情况,分别是两个点都在左半部分中,都在右半部分中,以及一个点在左半部分一个点在右半部分。第三种情况,可以通过先求解前两者的最小距离d,得到一个竖直带状区域,在此区域外的点肯定不是最小距离,并且可以证明每个点只需要和六个点进行比较(因为在每一边,两点之间的距离都必须大于等于d)。
主要是代码的实现部分和复杂度分析比较关键,详情可以见后面的第一个题。
k-d树法
详细的k-d树的介绍可以参考这两篇文章
KNN(三)–KD树详解及KD树最近邻算法
K-D TREE算法原理及实现
说的比较多,什么根据切分维度分割之类,不过应用在二维和三维空间,自然是非常简单的233,不需要这么麻烦,即先后按照x、y、z来分割空间,下面来实现一下二维空间的简单的k-d树的基本功能(能用来查询最近邻)。
k-d树的构建
依次选取x、y作为切分维度,取数据点在该维度的中值作为切分超平面,将中值左侧的数据点挂在左子树,右侧的数据点挂在右子树,递归的进行处理。一个例子:
k-d树的查询
查询过程:
递归进行,根据当前切分维度判断目标点所在空间,查找子树,然后根据已得到的最近距离,回溯到父节点时,判断是否需要进入另一空间查询。
代码可以参考最典型的平面最近点对的kd树法,做个题才是检验代码正确与否的最好方法。
题目
最典型的平面最近点对问题。
vijos1012 https://vijos.org/p/1012
分治的代码如下,影响复杂度的关键地方我做了注释:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cmath>
using namespace std;
struct point {
double x;
double y;
};
typedef vector<point> points;
double dis(point a, point b) {
return sqrt((a.x - b.x)*(a.x - b.x) + (a.y - b.y)*(a.y - b.y));
}
double findMinDis(const points& a,int left,int right) {
if (left == right) {
return 1e18;
}
if (left + 1 == right) {
return dis(a[left], a[right]);
}
int mid = (right + left) >> 1;
double dis1 = findMinDis(a, left, mid);//分治
double dis2 = findMinDis(a, mid, right);
double limit = min(dis1, dis2);//限定中间的带状区域
points between;
for (int i = left; i <= right; ++i) {//遍历所有点,求解在区域里的点,时间复杂度O(N),这里还可以优化为从中心向两端扩展,超出范围时break
if (abs(a[i].x - a[mid].x) < limit) {
between.push_back(a[i]);
}
}
sort(between.begin(), between.end(), [](point p1, point p2) {return p1.y < p2.y; });//将在区域里的点按y排序,最坏时间复杂度为O(NlogN)
double ans = limit;
for (int i =0; i < between.size(); ++i) {//遍历区域里的点,最坏时间复杂度为O(N)
for (int j = i + 1; j < i+7 && j < between.size(); ++j) {//对于每个点,需要查找的另外的点的数量都是个位数
ans = min(ans, dis(between[i], between[j]));
}
}
return ans;
}
int main() {
int n;
cin >> n;
points a;
for (int j = 0; j < n; ++j) {
double x, y;
cin >> x >> y;
a.push_back({ x,y });
}
sort(a.begin(), a.end(), [](point p1, point p2) {return p1.x < p2.x; });//首先将数据按x轴排序,时间复杂度O(NlogN)
double dis = findMinDis(a,0,a.size()-1);
printf("%.3f", dis);
system("pause");
return 0;
}
因此有T(n)=T(n/2)+nlog(n),故时间复杂度为O(n log(n) log(n))。另外可以通过归并排序的方式,避免每次按y坐标排序的nlog(n),将整体时间复杂度可以降低为nlog(n)。
k-d树法:
#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>
using namespace std;
const double INF = 1e10;
struct Point {
double x, y;
int id;//用来查询时避免同一个点求解到自身的距离(针对题目所设置)
};
typedef vector<Point> Points;
struct kdTreeNode {
int dim;//切分维度,0是x,1是y
kdTreeNode* left;
kdTreeNode* right;
Point p;
kdTreeNode(int d, Point val) :dim(d), p(val), left(NULL), right(NULL) {};
};
double dis(Point a, Point b) {
if (a.id == b.id) {
return INF;
}
return sqrt((a.x - b.x)*(a.x - b.x) + (a.y - b.y)*(a.y - b.y));
}
kdTreeNode* buildTree(Points& pts, int d,int left,int right) {
if (left > right) {
return NULL;
}
d %= 2;
int mid = (right - left) / 2 + left;
nth_element(pts.begin() + left, pts.begin() + mid, pts.begin() + right+1, [d](Point a, Point b) {return d == 0 ? a.x < b.x :a.y < b.y; });//找到中间点,并且左边小于中间点,右边大于(快速选择算法,时间复杂度O(n))
Point p = pts[mid];
kdTreeNode* root = new kdTreeNode(d, p);
root->left = buildTree(pts, d + 1, left, mid-1);
root->right = buildTree(pts,d + 1, mid + 1,right);
return root;
}
void query(kdTreeNode* t, Point target,double& minDis) {
if (t == NULL) {
return;
}
if (t->dim == 0) {
if (target.x < t->p.x) {
query(t->left, target, minDis);
}
else {
query(t->right, target, minDis);
}
minDis = min(minDis, dis(target, t->p));
if (minDis > abs(target.x-t->p.x)) {
if (target.x < t->p.x) {
query(t->right, target, minDis);
}
else {
query(t->left, target, minDis);
}
}
}
else {
if (target.y < t->p.y) {
query(t->left, target, minDis);
}
else {
query(t->right, target, minDis);
}
minDis = min(minDis, dis(target, t->p));
if (minDis > abs(target.y - t->p.y)) {
if (target.y < t->p.y) {
query(t->right, target, minDis);
}
else {
query(t->left, target, minDis);
}
}
}
}
int main() {
int n;
cin >> n;
Points a;
for (int j = 0; j < n; ++j) {
double x, y;
cin >> x >> y;
a.push_back({ x,y,j });
}
kdTreeNode* t = buildTree(a, 0, 0, a.size() - 1);
double ans = INF;
for (int i = 0; i < a.size(); i++) {
double minDis = INF;
query(t, a[i],minDis);
ans = min(ans, minDis);
}
printf("%.3f", ans);
system("pause");
return 0;
}
两个点集间的最近点(也是一道腾讯笔试题)
poj 3714 http://poj.org/problem?id=3714
https://www.acwing.com/problem/content/description/121/
分治的代码如下:
#pragma GCC optimize(2)
#pragma G++ optimize(2)
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cmath>
using namespace std;
struct point {
double x, y;
int id;
};
typedef vector<point> points;
const double inf = 1e10;
bool cmpx(point a, point b) {
return a.x < b.x;
}
bool cmpy(point a, point b) {
return a.y < b.y;
}
double dis(point a, point b) {
if (a.id == b.id) {
return inf;
}
return sqrt((a.x - b.x)*(a.x - b.x) + (a.y - b.y)*(a.y - b.y));
}
double findMinDisHelper(const points& a, int left, int right) {
if (left >= right) {
return inf;
}
if (left + 1 == right) {
return dis(a[left], a[right]);
}
int mid = (right + left) >> 1;
double dis1 = findMinDisHelper(a, left, mid);
double dis2 = findMinDisHelper(a, mid+1, right);
double limit = min(dis1, dis2);
points between;
for (int i = left; i <= right; ++i) {
if (abs(a[i].x - a[mid].x) < limit) {
between.push_back(a[i]);
}
}
sort(between.begin(), between.end(), cmpy);
double ans = limit;
for (int i = 0; i < between.size(); ++i) {
for (int j = i - 1; j>=0 && abs(between[i].y - between[j].y) < limit; --j) {
ans = min(ans, dis(between[i], between[j]));
}
}
return ans;
}
double findMinDis(points a, points b) {
points all;
for (int i = 0; i < a.size(); i++) {
all.push_back(a[i]);
}
for (int i = 0; i < b.size(); i++) {
all.push_back(b[i]);
}
sort(all.begin(), all.end(), cmpx);
return findMinDisHelper(all, 0, all.size() - 1);
}
int main() {
ios::sync_with_stdio(false);
int t;
cin >> t;
for (int i = 0; i < t; i++) {
int n;
cin >> n;
points a;
points b;
for (int j = 0; j < n; j++) {
double x, y;
cin >> x >> y;
a.push_back({ x,y,1});
}
for (int j = 0; j < n; j++) {
double x, y;
cin >> x >> y;
b.push_back({ x,y,2});
}
double dis = findMinDis(a, b);
printf("%.3f\n", dis);
}
system("pause");
return 0;
}
不过吐槽下,这题在北大oj上会超时,但在acwing上可以过,应该是有的数据会使得此算法退化,比如两类点各在一边的情况,解决方式是通过将数据随机旋转一个角度。