洛谷 P4178 Tree(点分治模板)

点分治用于解决与树上路径有关的统计问题。

点分治的思想类似一般的分治:对于一棵树,一方面处理跨越不同子树的路径,即经过根的路径(相当于分治后合并时考虑跨越不同区间的影响),另一方面递归到各个子树中处理子树内的路径(相当于分治中考虑区间内的问题)。

而实际处理中,我们希望递归层数尽量小。因此,我们在处理一棵无根树时,选取的根要满足:其最大子树的大小尽量小。这个根称为重心。

总结起来,点分治流程如下:

1.找重心。

怎么找?

dfs一遍,计算每个点为根时其各个子树的大小,比较其最大子树的大小与之前找到最优点的最大子树的大小。如果更小,则更新最优点。

代码:

void getroot(int u,int fa){
    size[u]=1,F[u]=0;//size[u]为u为根的树的大小,F[u]为u为根的树的最大子树大小
    for(int i=head[u];i!=-1;i=G[i].nxt){
        int v=G[i].to;
        if(v==fa||vis[v]) continue;
        getroot(v,u);
        size[u]+=size[v];
        F[u]=max(F[u],size[v]);
    }
    F[u]=max(F[u],sum-size[u]);//sum为当前整棵树的大小,这里是考虑以u的父亲为根的子树
    if(F[u]<F[root]) root=u;
}

2.处理经过根的路径。要注意的是,我们处理时会遇到路径两端在同一子树内的路径,这种路径统计时需要删去(容斥原理)。

3.标记根节点(相当于删去)。

4.递归处理各个子树。

以本题为例:

1.找重心。

2.dfs求出当前树中各个结点到根的距离,然后将所有距离排序,统计符合条件的点对数,具体方法如下:
设由小到大排好序的所有距离为Q[i],两个指针l和r,初始时l+1为在开头,r为末尾。
如果Q[l+1]+Q[r]k,那么对于一端为结点l+1的路径,另一端为编号小于等于r且不等于l+1的结点都符合要求,而其中小于l+1的部分都已经统计过了,所以ans+=rl1,之后l++
否则,r,直到满足上述要求。
但这样得到的还不是正确答案。我们将所有结点放在一起统计,就可能存在路径的两端都来自同一子树的情况,如下图:

这里写图片描述

在上述算法中,如果路径3——2——1——2——4的长度小于等于k,那么它就会被我们统计到。显然,这种路径是不符合要求的。如何将其删去?
我们可以在以2号点为根的子树中进行一次上述统计,其中计算所有点到根2号点的距离时都加上边1——2的长度。这样我们计算出的距离实际是各个点到1号点的距离,因此统计出的路径也是前面提到的需要删去的路径。对各个子树都进行一遍这样的操作,从答案中减去即可。

3.标记根节点(相当于删去)。

4.递归处理各个子树。

讨论时间复杂度:

可以发现,递归层数最多为logn,因为每次找到重心后,最大子树一定不超过整棵树大小的一半。每层递归中dfs一遍所有子树为O(n),排序为O(nlogn)。因此总时间复杂度O(nlog2n)

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=40010;
int n,k,ans;
int root,sum,size[maxn],F[maxn],dis[maxn];
bool vis[maxn];
int Q[maxn],l,r;

int head[maxn],Ecnt;
struct edge{
    int to,nxt,w;
}G[maxn*2];

inline void addE(int u,int v,int w){
    G[Ecnt]=(edge){v,head[u],w};
    head[u]=Ecnt++;
    G[Ecnt]=(edge){u,head[v],w};
    head[v]=Ecnt++;
}

void getroot(int u,int fa){
    size[u]=1,F[u]=0;
    for(int i=head[u];i!=-1;i=G[i].nxt){
        int v=G[i].to;
        if(v==fa||vis[v]) continue;
        getroot(v,u);
        size[u]+=size[v];
        F[u]=max(F[u],size[v]);
    }
    F[u]=max(F[u],sum-size[u]);
    if(F[u]<F[root]) root=u;
}

void getdis(int u,int fa){
    Q[++r]=dis[u];
    for(int i=head[u];i!=-1;i=G[i].nxt){
        int v=G[i].to;
        if(v==fa||vis[v]) continue;
        dis[v]=dis[u]+G[i].w;
        getdis(v,u);
    }
}

int calc(int u,int d){
    int ret=0;
    l=r=0;
    dis[u]=d;
    getdis(u,0);
    sort(Q+1,Q+r+1);
    while(l<r){
        if(Q[l+1]+Q[r]<=k) ret+=r-l-1,++l;
        else --r;
    }
    return ret;
}

void solve(int u){
    ans+=calc(u,0);
    vis[u]=1;
    for(int i=head[u];i!=-1;i=G[i].nxt){
        int v=G[i].to;
        if(vis[v]) continue;
        ans-=calc(v,G[i].w);
        sum=size[v],root=0;
        getroot(v,0);
        solve(root);
    }
}

int main(){
    memset(head,-1,sizeof(head));
    scanf("%d",&n);
    for(int i=1;i<n;i++){
        int u,v,w;
        scanf("%d%d%d",&u,&v,&w);
        addE(u,v,w);
    }
    scanf("%d",&k);
    sum=F[0]=n,root=0;
    getroot(1,0);
    solve(root);
    printf("%d\n",ans);
    return 0;
}
阅读更多
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页