一、题目
二、解法
做了这道题才真正掌握了 k-d \text{k-d} k-d树板子,相比于二维,多维写法有修改。
首先是建树,需要采用循环维度构建的方式(特别好写)
然后是询问,先拿到树上点的维度(是那一维的中位数),然后看询问的点在这一维离左边还是右边更近,就去那边继续查询,维护一个优先队列,看另一边需不需要查询(在这一维询问点距离树上点的距离必须比优先队列最大距离小)
口胡可能难以理解,请看代码吧 q w q qwq qwq,换了写法跑的贼快。
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
const int M = 50005;
const int inf = 2e9+1;
int read()
{
int x=0,flag=1;char c;
while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*flag;
}
int n,m,k,opt,rt,cnt,mi[M][5],mx[M][5],ls[M],rs[M];
struct node
{
int x[5];
}a[M],b[15],v[M];
struct data
{
int val;node p;
bool operator < (const data &b) const
{
return val<b.val;
}
};
priority_queue<data> q;
int sqr(int x)
{
return x*x;
}
void up(int x)
{
for(int i=0;i<m;i++)
mi[x][i]=mx[x][i]=v[x].x[i];
if(ls[x])
for(int i=0;i<m;i++)
{
mx[x][i]=max(mx[x][i],mx[ls[x]][i]);
mi[x][i]=min(mi[x][i],mi[ls[x]][i]);
}
if(rs[x])
for(int i=0;i<m;i++)
{
mx[x][i]=max(mx[x][i],mx[rs[x]][i]);
mi[x][i]=min(mi[x][i],mi[rs[x]][i]);
}
}
int cmp(node a,node b)
{
return a.x[opt]<b.x[opt];
}
void build(int &x,int l,int r,int dep)
{
if(l>r) return ;
x=++cnt;
ls[x]=rs[x]=0;
int mid=(l+r)>>1;
opt=dep%m;
nth_element(a+l,a+mid,a+r+1,cmp);
v[x]=a[mid];
build(ls[x],l,mid-1,dep+1);
build(rs[x],mid+1,r,dep+1);
up(x);
}
int f(node a,node b)
{
int r=0;
for(int i=0;i<m;i++)
r+=sqr(a.x[i]-b.x[i]);
return r;
}
void query(int x,node y,int dep)
{
int di=f(v[x],y),id=dep%m,fg=0,l=ls[x],r=rs[x];
if(y.x[id]>=v[x].x[id]) swap(l,r);
if(l) query(l,y,dep+1);
if(q.size()<k)
{
q.push(data{di,v[x]});
fg=1;
}
else
{
if(di<q.top().val)
{
q.pop();
q.push(data{di,v[x]});
}
if(sqr(y.x[id]-v[x].x[id])<q.top().val)
fg=1;
}
if(r && fg) query(r,y,dep+1);
}
signed main()
{
while(~scanf("%d %d",&n,&m))
{
rt=cnt=0;
for(int i=1;i<=n;i++)
for(int j=0;j<m;j++)
a[i].x[j]=read();
build(rt,1,n,0);
int T=read();
while(T--)
{
node y;
while(!q.empty()) q.pop();
for(int i=0;i<m;i++) y.x[i]=read();
k=read();
query(rt,y,0);
printf("the closest %d points are:\n",k);
for(int i=1;i<=k;i++) b[i]=q.top().p,q.pop();
for(int i=k;i>=1;i--,puts(""))
for(int j=0;j<m;j++)
{
if(j==0) printf("%d",b[i].x[0]);
else printf(" %d",b[i].x[j]);
}
}
}
}