kd tree原理详解: KD-树介绍
代码模板和部分题借鉴:kd-tree小结
例题一: P4475 巧克力王国
思路:根据巧克力a和可可b建立kd-tree,每个节点维护子树的最大的x,y以及最小的x,y,一个子树的权值和sum,每次查询,如果四对极值点 x y满足 x*a+y*b<h,那么这颗子树都可以对答案产生贡献,直接返回sum,否则再判断该节点是否满足条件以及两个儿子是否满足条件,模板题。
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxn=5e5+10;
int n,m,root,cur;
ll a,b,c;
struct P
{
int d[2],mx[2],mn[2],ls,rs,v;
ll sum;
int& operator[](int x){return d[x];}
friend bool operator<(P a,P b){return a[cur]<b[cur];}
}p[maxn];
struct kdtree
{
P t[maxn];
void up(int k)
{
int ls=t[k].ls,rs=t[k].rs;
for(int i=0;i<2;i++)
{
t[k].mn[i]=t[k].mx[i]=t[k][i];
if(ls)t[k].mn[i]=min(t[k].mn[i],t[ls].mn[i]);
if(rs)t[k].mn[i]=min(t[k].mn[i],t[rs].mn[i]);
if(ls)t[k].mx[i]=max(t[k].mx[i],t[ls].mx[i]);
if(rs)t[k].mx[i]=max(t[k].mx[i],t[rs].mx[i]);
}
t[k].sum=t[ls].sum+t[rs].sum+t[k].v;
}
int build(int l,int r,int now)
{
cur=now;
int mid=(l+r)/2;
nth_element(p+l,p+mid,p+r+1);
t[mid]=p[mid];
if(l<mid)t[mid].ls=build(l,mid-1,!now);
if(r>mid)t[mid].rs=build(mid+1,r,!now);
up(mid);
return mid;
}
int check(ll x,ll y)
{
return a*x+b*y<c;
}
int cal(P a)
{
int tmp=0;
tmp+=check(a.mx[0],a.mx[1]);
tmp+=check(a.mx[0],a.mn[1]);
tmp+=check(a.mn[0],a.mx[1]);
tmp+=check(a.mn[0],a.mn[1]);
return tmp;
}
ll qu(int o)
{
int tmp=cal(t[o]);
if(tmp==4)
return t[o].sum;
if(!tmp)return 0;
ll res=0;
if(check(t[o][0],t[o][1]))
res+=t[o].v;
if(t[o].ls)res+=qu(t[o].ls);
if(t[o].rs)res+=qu(t[o].rs);
return res;
}
}kd;
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d%d%d",&p[i][0],&p[i][1],&p[i].v);
root=kd.build(1,n,0);
while(m--)
{
scanf("%lld%lld%lld",&a,&b,&c);
printf("%lld\n",kd.qu(root));
}
}
例题二:Hide and Seek
枚举所有点,再用kd-tree求该点到平面内距离最大的点以及最小的点只差即可,kd-tree经典模板题
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
const int maxn=5e5+10,inf=1e9;
int cur,n,root;
struct P
{
int d[2],mx[2],mn[2],ls,rs;
int& operator[](int x){return d[x];}
friend bool operator<(P a,P b){return a[cur]<b[cur];}
friend int dis(P x,P y) {return abs(x[0]-y[0])+abs(x[1]-y[1]);}
}p[maxn];
struct kdtree
{
P t[maxn],T;
int ans;
void up(int o)
{
int ls=t[o].ls,rs=t[o].rs;
for(int i=0;i<2;i++)
{
t[o].mn[i]=t[o].mx[i]=t[o][i];
if(ls)t[o].mn[i]=min(t[o].mn[i],t[ls].mn[i]);
if(rs)t[o].mn[i]=min(t[o].mn[i],t[rs].mn[i]);
if(ls)t[o].mx[i]=max(t[o].mx[i],t[ls].mx[i]);
if(rs)t[o].mx[i]=max(t[o].mx[i],t[rs].mx[i]);
}
}
int build(int l,int r,int now)
{
cur=now;
int mid=(l+r)/2;
nth_element(p+l,p+mid,p+r+1);
t[mid]=p[mid];
if(l<mid)t[mid].ls=build(l,mid-1,!now);
if(r>mid)t[mid].rs=build(mid+1,r,!now);
up(mid);
return mid;
}
int getv(int tp,P a)
{
int res=0;
for(int i=0;i<2;i++)
{
if(!tp)
res+=max(T[i]-a.mx[i],0)+max(a.mn[i]-T[i],0);
else
res+=max(abs(T[i]-a.mx[i]),abs(T[i]-a.mn[i]));
}
return res;
}
void qmin(int o)
{
if(dis(t[o],T))ans=min(ans,dis(t[o],T));
int ls=t[o].ls,rs=t[o].rs,t1=inf,t2=inf;
if(ls)t1=getv(0,t[ls]);
if(rs)t2=getv(0,t[rs]);
if(t1<t2)
{
if(t1<ans)qmin(ls);
if(t2<ans)qmin(rs);
}
else
{
if(t2<ans)qmin(rs);
if(t1<ans)qmin(ls);
}
}
void qmax(int o)
{
ans=max(ans,dis(t[o],T));
int ls=t[o].ls,rs=t[o].rs,t1=-inf,t2=-inf;
if(ls)t1=getv(1,t[ls]);
if(rs)t2=getv(1,t[rs]);
if(t1>t2)
{
if(t1>ans)qmax(ls);
if(t2>ans)qmax(rs);
}
{
if(t2>ans)qmax(rs);
if(t1>ans)qmax(ls);
}
}
int qu(int tp,int x,int y)
{
T[0]=x,T[1]=y;
if(tp)ans=-inf,qmax(root);
else ans=inf,qmin(root);
return ans;
}
}kd;
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d%d",&p[i][0],&p[i][1]);
root=kd.build(1,n,0);
int ans=inf;
for(int i=1;i<=n;i++)
{
int res=kd.qu(1,p[i][0],p[i][1])-kd.qu(0,p[i][0],p[i][1]);
ans=min(ans,res);
}
printf("%d\n",ans);
}