KDtree(bzoj2648)

版权声明:本文为博主原创文章,转载请注明源网址blog.csdn.net/leo_h1104 https://blog.csdn.net/Leo_h1104/article/details/51478875

看到学姐开了博客我也坐不住啦

这道题标准KDtree模板题,借鉴一下黄学长的代码轻松水过

建树:

KD树的结构类似于平衡树,对于已知点集,求出其在某一维度上排序后的中间点(可以用stl的nth_element),然后把空间一分为二,再对每个部分递归建树,返回子部分分割点的编号,作为当前部分分割点的左右儿子。关于维度选择,可以求方差最大的维度,然而为了简化,这里选择横竖交替分割

int build(int l,int r,bool dir)
{
    nthdir=dir;
    nth_element(p+l,p+mid,p+r+1,cmp);
    mn[mid][0]=mx[mid][0]=p[mid][0];
    mn[mid][1]=mx[mid][1]=p[mid][1];
    if(l<mid) lch[mid]=build(l,mid-1,dir^1);
    if(r>mid) rch[mid]=build(mid+1,r,dir^1);
    update(mid);
    return mid;
}
其中update函数维护了mn,mx两个数组,分别保存该点所分割的点集中最大最小坐标,换句话说,是能包裹这个点集的矩形的各边坐标

inline void update(int x)
{
    int ll=lch[x];
    int rr=rch[x];
    if(ll)
    {
        mn[x][0]=min(mn[ll][0],mn[x][0]);
        mn[x][1]=min(mn[ll][1],mn[x][1]);
        mx[x][0]=max(mx[ll][0],mx[x][0]);
        mx[x][1]=max(mx[ll][1],mx[x][1]);
    }
    if(rr)
    {
        mn[x][0]=min(mn[rr][0],mn[x][0]);
        mn[x][1]=min(mn[rr][1],mn[x][1]);
        mx[x][0]=max(mx[rr][0],mx[x][0]);
        mx[x][1]=max(mx[rr][1],mx[x][1]);
    }
}


插入不改变树的形态,直接判断插入点在分割点的哪边,如果存在子节点递归进去,否则创建子节点

void insert(int t,point P,bool dir)
{
    if(P[dir]>=p[t][dir])
    {
        if(rch[t]) insert(rch[t],P,dir^1);
        else {
            rch[t]=++n;
            p[n]=P;
            mn[n][0]=mx[n][0]=P[0];
            mn[n][1]=mx[n][1]=P[1];
        }
    }
    else
    {
        if(lch[t]) insert(lch[t],P,dir^1);
        else {
            lch[t]=++n;
            p[n]=P;
            mn[n][0]=mx[n][0]=P[0];
            mn[n][1]=mx[n][1]=P[1];
        }
    }
    update(t);
}
查询时对于每一个被分割的点集,先用到分割点的距离更新答案,然后判断查询点到两边矩形的最短距离是否已经小于答案,若小于,这个矩形内的点也有可能更新答案,则递归进入求解。为了效率最优,优先选择到查询点距离较小的矩形递归


inline int outd(int k,point P)
{
    int tmp=0;
    tmp+=max(0,mn[k][0]-P[0]);
    tmp+=max(0,mn[k][1]-P[1]);
    tmp+=max(0,P[0]-mx[k][0]);
    tmp+=max(0,P[1]-mx[k][1]);
    return tmp;
}
int ans=INF;
void query(int t,point P)
{
    ans=min(ans,dist(p[t],P));
    int dl=INF,dr=INF;
    if(lch[t])dl=outd(lch[t],P);
    if(rch[t])dr=outd(rch[t],P);
    if(dl<dr)
    {
        if(dl<ans) query(lch[t],P);
        if(dr<ans) query(rch[t],P);
    }
    else
    {
        if(dr<ans) query(rch[t],P);
        if(dl<ans) query(lch[t],P);
    }
}


完整代码:

/**************************************************************
    Problem: 2648
    User: Leo_h
    Language: C++
    Result: Accepted
    Time:14520 ms
    Memory:32060 kb
****************************************************************/
 
#include<cstdio>
#include<algorithm>
#include<cstdlib>
#include<cstring>
using namespace std;
int n,m;
#define maxn 1000000
#define INF 0x7f7f7f7f
#define maxdim 2
#define mid (l+r)/2
int mn[maxn][2];
int mx[maxn][2];
int lch[maxn];
int rch[maxn];
bool nthdir;
struct point
{
    int pos[maxdim];
    int& operator[](int x)
    {
        return pos[x];
    }
    point(int x,int y)
    {
        pos[0]=x;
        pos[1]=y;
    }
    point(){}
}p[maxn];
inline int read(int &d)
{
    char ch=getchar();
    int tmp=1;
    d=0;
    while(ch<'0'||ch>'9')
    {
        if(ch=='-')tmp=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        d=d*10+ch-'0';
        ch=getchar();
    }
    return d=d*tmp;
}
inline bool cmp(point a,point b)
{
    return a[nthdir]<b[nthdir];
}
inline int dist(point x,point y)
{
    return abs(x[0]-y[0])+abs(x[1]-y[1]);
}
inline void update(int x)
{
    int ll=lch[x];
    int rr=rch[x];
    if(ll)
    {
        mn[x][0]=min(mn[ll][0],mn[x][0]);
        mn[x][1]=min(mn[ll][1],mn[x][1]);
        mx[x][0]=max(mx[ll][0],mx[x][0]);
        mx[x][1]=max(mx[ll][1],mx[x][1]);
    }
    if(rr)
    {
        mn[x][0]=min(mn[rr][0],mn[x][0]);
        mn[x][1]=min(mn[rr][1],mn[x][1]);
        mx[x][0]=max(mx[rr][0],mx[x][0]);
        mx[x][1]=max(mx[rr][1],mx[x][1]);
    }
}
int build(int l,int r,bool dir)
{
    nthdir=dir;
    nth_element(p+l,p+mid,p+r+1,cmp);
    mn[mid][0]=mx[mid][0]=p[mid][0];
    mn[mid][1]=mx[mid][1]=p[mid][1];
    if(l<mid) lch[mid]=build(l,mid-1,dir^1);
    if(r>mid) rch[mid]=build(mid+1,r,dir^1);
    update(mid);
    return mid;
}
inline int outd(int k,point P)
/*this function get the distance from P 
to the cube made up of all points in k*/
/*if P is in the area divided by p[k],
this function will return 0*/
{
    int tmp=0;
    tmp+=max(0,mn[k][0]-P[0]);
    tmp+=max(0,mn[k][1]-P[1]);
    tmp+=max(0,P[0]-mx[k][0]);
    tmp+=max(0,P[1]-mx[k][1]);
    return tmp;
}
int ans=INF;
void query(int t,point P)
{
    ans=min(ans,dist(p[t],P));
    int dl=INF,dr=INF;
    if(lch[t])dl=outd(lch[t],P);
    if(rch[t])dr=outd(rch[t],P);
    if(dl<dr)
    {
        if(dl<ans) query(lch[t],P);
        if(dr<ans) query(rch[t],P);
    }
    else
    {
        if(dr<ans) query(rch[t],P);
        if(dl<ans) query(lch[t],P);
    }
}
void insert(int t,point P,bool dir)
{
    if(P[dir]>=p[t][dir])
    {
        if(rch[t]) insert(rch[t],P,dir^1);
        else {
            rch[t]=++n;
            p[n]=P;
            mn[n][0]=mx[n][0]=P[0];
            mn[n][1]=mx[n][1]=P[1];
        }
    }
    else
    {
        if(lch[t]) insert(lch[t],P,dir^1);
        else {
            lch[t]=++n;
            p[n]=P;
            mn[n][0]=mx[n][0]=P[0];
            mn[n][1]=mx[n][1]=P[1];
        }
    }
    update(t);
}
int main()
{
    read(n);
    read(m);
    for(int i=1;i<=n;i++)
        read(p[i][0]),read(p[i][1]);
    int root=build(1,n,0);
    while(m--)
    {
        int mode,x,y;
        read(mode);
        read(x);
        read(y);
        if(mode==1) insert(root,point(x,y),0);
        else {
            ans=INF;
            query(root,point(x,y));
            printf("%d\n",ans);
        }
    }
    return 0;
}



阅读更多

没有更多推荐了,返回首页