众所周知,树分治有两种:点分治和边分治。顾名思义,点分治是按点进行分治,为了让子问题的规模尽量小,我们通常选择重心做为分治的点;而边分治通常选择使得所分离出来的两棵子树的结点个数尽量平均的边。求重心和这样的边均可以用树DP的方法解决。通常点分治和边分治的递归层数为
O(logn)
,在实际运用中,边分治的层数经常达到
O(n)
效率低下……
以上均属于扯淡……
关于树的分治有一篇IOI的论文《分治算法在树的路径问题中的应用》,讲得比较详细。
先附上重心的代码
int sz[MAXN], sz2[MAXN], root;
void getroot(int u,int fa)
{
sz[u]=1, sz2[u]=0;
for(int p=adj[u], v;~p;p=edge[p].next)
if((v=edge[p].v)!=fa)
{
getroot(v,u);
sz[u]+=sz[v];
sz2[u]=max(sz2[u],sz[v]);
}
sz2[u]=max(sz2[u],sum-sz[u]);
if(sz2[u]<sz2[root])root=u;
}
1.poj1741
男人八题之一……
题目大意:有一颗节点为
n
的树,求有多少个点对
如何用点分治做?
考虑经过点
root
的点对
(u,v)
,其中
dis(u,v)<=k
。
为了防止退化为一条链,每次分治都以树的重心分治。
现在的问题是如何统计满足上述条件的点对的个数?
别忘了这是一个计数问题。
ans=以root为根的子树中dis(u,v)<=k的个数−以son(root)所在子树中dis(u′,v′)+2∗p−>w<=k
这里有一个经典的线性扫描的方法。
#include <iostream>
#include <cstdio>
#include <algorithm>
#define max(a,b) ((a)>(b)?(a):(b))
#define MAXN 10050
using namespace std;
int n, k, pos, sum, root, a, b, c;
struct node
{
int v, w, next;
}edge[MAXN<<1];
int adj[MAXN], dis[MAXN], ans;
bool vis[MAXN];
inline void add(int a,int b,int c)
{
edge[pos].v=b, edge[pos].w=c, edge[pos].next=adj[a];
adj[a]=pos;
++pos;
}
int sz[MAXN], sz2[MAXN];
void getroot(int u,int fa)
{
//在计算重心的过程中,由于子树大小的变化,所以sz[u],sz2[u]要时时更新
sz[u]=1, sz2[u]=0;
for(int p=adj[u], v;~p;p=edge[p].next)
//!vis[v]是防止访问到子树以外的节点
if((v=edge[p].v)!=fa&&!vis[v])
{
getroot(v,u);
sz[u]+=sz[v];
sz2[u]=max(sz2[u],sz[v]);
}
sz2[u]=max(sz2[u],sum-sz[u]);
if(sz2[u]<sz2[root])root=u;
}
//求出子树中每一个点到重心的距离
int d[MAXN], cnt;
void getdis(int u,int fa)
{
d[++cnt]=dis[u];
for(int p=adj[u], v;~p;p=edge[p].next)
//!vis[v]是防止访问到子树以外的节点
if((v=edge[p].v)!=fa&&!vis[v])
{
dis[v]=dis[u]+edge[p].w;
getdis(v,u);
}
}
int cal(int u,int init)
{
dis[u]=init;
cnt=0;
getdis(u,0);
//排序以进行线性扫描
sort(d+1,d+cnt+1);
int l=1, r=cnt, ans=0;
while(l<r)
if(d[l]+d[r]<=k)ans+=r-l++;
else --r;
return ans;
}
void dfs(int u)
{
vis[u]=1;
ans+=cal(u,0);
for(int p=adj[u], v;~p;p=edge[p].next)
if(!vis[(v=edge[p].v)])
{
//对dis[v]赋初值是为了方便计算
//dis(u,v)+2*edge[p].w=dis[u]+edge[p].w+dis[v]
ans-=cal(v,edge[p].w);
sum=sz[v];
getroot(v,root=0);
dfs(root);
}
}
int main()
{
while(~scanf("%d%d",&n,&k)&&n+k)
{
for(int i=1;i<=n;++i)adj[i]=-1, vis[i]=0;
pos=ans=0;
for(int i=1;i<n;++i)
{
scanf("%d%d%d",&a,&b,&c);
add(a,b,c), add(b,a,c);
}
sz2[0]=sum=n;
getroot(1,root=0);
dfs(root);
printf("%d\n",ans);
}
return 0;
}