题意:首先给出n(n<=50000)个点,每个点最多k(k<=10)维。然后是T组查询,每次查询给定一个点,求距离这个点最近的m个点。按照距离从小到大输出。
思路:KD树,讲解可以参考统计学习教材K近邻那章。我的代码是参照(http://blog.csdn.net/wxfwxf328/article/details/8158187)写的。
首先是建树,难点不多,用到了nth_element函数,正好符合我们的所需:将中间点固定,小于这个点的弄到左边去,大于的弄到右边去。这里的一个技巧是point结构体的小于号重构,里面居然可以用一个全局变量当做变量。
查询的时候一次将整个树遍历一遍(中间有大量剪枝)。
很多博客写的在k维kd树上查询最邻近点的最坏情况复杂度为O(k * N^(1-1/k)).
#include <cstdio>
#include <queue>
#include <algorithm>
#include <cstring>
using namespace std;
#define INF 0x3fffffff
#define clr(s,t) memset(s,t,sizeof(s));
#define N 50005
#define Q 10005
int k,n,m,T,idx,son[N<<2];
struct point{
int s[7];
bool operator<(const point &b)const{
return s[idx]<b.s[idx];
}
}p[N],base,kdt[N<<2];
struct node{
point pp;
double dis;
bool operator<(const node &b)const{
return dis < b.dis;
}
}res[12];
priority_queue<struct node> q;
double dist(point a,point b){
double res = 0;
for(int i = 0;i<k;i++)
res += (a.s[i]-b.s[i])*(a.s[i]-b.s[i]);
return res;
}
void build(int r,int a,int b,int d){
if(a>b)
return;
idx = d%k; //此次按照第几个维度来进行二分
int mid = (a+b)>>1;
son[r] = b;
son[r*2] = son[r*2+1] = -1; //如果两个儿子都没有结点,当前结点当然就是叶节点
nth_element(p+a, p+mid, p+b+1);
kdt[r] = p[mid];
build(r*2, a, mid-1, d+1);
build(r*2+1, mid+1, b, d+1);
}
void query(int r,int d){
int x,y,id = d%k,flag = 1;
if(son[r] == -1)
return;
node tmp; //表示当前子树根结点
tmp.pp = kdt[r];
tmp.dis = dist(tmp.pp,base);//根节点到待查结点的距离
x = r*2;
y = r*2+1;
if(tmp.pp.s[id] < base.s[id])//先去待查结点所在的那一边去查找(为了下文方便,永远是先找x)
swap(x,y);
query(x, d+1);
if(q.size()<m)//如果现在找到的结点还不到m个,那么将这个根节点加进去啦
q.push(tmp);
else{
if(q.top().dis > tmp.dis){//根节点到base的距离比目前找到的最大距离要小,那么根结点加进去
q.pop();
q.push(tmp);
}
if((tmp.pp.s[id]-base.s[id])*(tmp.pp.s[id]-base.s[id]) > q.top().dis)//这一步是剪枝,相当于教材里描述的画那个圆
flag = 0;
}
if(flag)
query(y, d+1);
}
int main(){
while(scanf("%d %d",&n,&k)!=EOF){
int i,j;
idx = 0;
for(i = 1;i<=n;i++)
for(j = 0;j<k;j++)
scanf("%d",&p[i].s[j]);
build(1,1,n,0);
scanf("%d",&T);
while(T--){
for(j = 0;j<k;j++)
scanf("%d",&base.s[j]);
scanf("%d",&m);
query(1,0);
for(i = 1;i<=m;i++){
res[i] = q.top();
q.pop();
}
printf("the closest %d points are:\n",m);
for(i = m;i>=1;i--){
for(j = 0;j<k-1;j++)
printf("%d ",res[i].pp.s[j]);
printf("%d\n",res[i].pp.s[k-1]);
}
}
}
return 0;
}