树的分治算法常见两种,一是点分治,二是边分治,本文只考虑点分治。
点分治,顾名思义,首先选取一个点将无根树转换为有根树,再递归处理每一棵以根节点的儿子为根的子树。
对于树的分治算法来说,递归的深度往往决定着算法效率的高低,所以,该如何选取这个根节点呢?最坏的情况,树退化成链,选取的根节点是链头,时间复杂度O(n)。所以,我们选取的根节点要保证最大的子树最小,也就是选择树的重心作为根节点。递归深度最坏为O(logn)
考虑树上点对的问题,例题求树上是否存在两点满足距离等于k,点数n<=1e4, 询问次数m<=1e2
已知当前根节点为r,路径存在两种情况:
1.经过r
2.被r的某一棵子树包含,不经过r
对于第二种情况的路径又可以进行递归处理,看作是子问题进行分治,变成第一种情况。
所以只用考虑第一种路径的计算。
假设有点x、y分属于不同的子树,路径(x,y)看作(x,r)+(r,y),通过dfs()计算两点到根的距离,即判断dis[x]+dis[y]==k。为了防止误判第二种情况,记录x、y所属的子树根节点sub_root[],求dis[x]+dis[y] == k && sub_root[x] != sub_root[y] 。计算过程见Calc()部分
处理点对后删除当前根节点(打标记),对所有子树递归处理。
下附P3806参考程序
//P3806
#include <bits/stdc++.h>
using namespace std;
const int maxn = 10005;
const int maxm = 105;
struct E
{
int v,w,ne;
}edge[maxn<<1];
int tot = 1,head[maxn];
int n,m,k[maxm];
bool ok[maxm];
void add(int u,int v,int w)
{
edge[++tot] = (E){v,w,head[u]};
head[u] = tot;
}
void read()
{
int u,v,w;
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
for(int i=1;i<=m;i++)
{ scanf("%d",&k[i]);
if(!k[i])
ok[i] = true;
}
}
bool vis[maxn],w[maxn];
int sz[maxn],pos,ans;
long long dis[maxn];//不开ll也没问题,最坏情况退化成链max(n)*max(w) = 1e8
int sub_root[maxn],sub_node[maxn],total_node = 0;
void dfs_root(int x,int S)//重心
{
int y,max_part = 0;
sz[x] = 1,vis[x] = 1;
for(int i=head[x];i;i=edge[i].ne)
{
y = edge[i].v;
if(vis[y] || w[y]) continue;
dfs_root(y,S);
sz[x]+=sz[y];
max_part = max(max_part,sz[y]);
}
max_part = max(max_part,S-sz[x]);
if(max_part < ans)
{
ans = max_part;
pos = x;
}
}
void dfs(int x)//到根节点的距离
{
int y,z;
vis[x] = 1;
for(int i=head[x];i;i=edge[i].ne)
{ y = edge[i].v, z = edge[i].w;
if(vis[y] || w[y]) continue;
dis[y] = dis[x]+z;
sub_node[++total_node] = y;
sub_root[y] = sub_root[x];
dfs(y);
}
}
bool cmp(int x,int y)
{
return dis[x] < dis[y];
}
void work(int x,int S)
{ ans = S;
memset(vis,0,sizeof(vis));
dfs_root(x,S);//找重心pos
memset(vis,0,sizeof(vis));
memset(dis,0,sizeof(dis));
total_node = 0;
sub_root[pos] = sub_node[++total_node] = pos;
w[pos] = 1;
for(int i=head[pos];i;i=edge[i].ne)//处理以根的儿子为根的子树
{ register int y = edge[i].v,z = edge[i].w;
if(vis[y] || w[y]) continue;
dis[y] = z;
sub_root[y] = y,sub_node[++total_node] = y;
dfs(y);
}
//Calc(),判断所有的k,用二分查找速度会更快
sort(sub_node+1,sub_node+total_node+1,cmp);//按照点到子树根的距离排序
for(int i=1;i<=m;i++)
{
if(ok[i]) continue;
int l = 1,r = total_node;
while(l<r)
{ if(dis[sub_node[r]]>k[i])
r--;
else if(dis[sub_node[l]]+dis[sub_node[r]] < k[i])
l++;
else if(dis[sub_node[l]]+dis[sub_node[r]] > k[i])
r--;
else if(sub_root[sub_node[l]] == sub_root[sub_node[r]])//距离之和相同,但是属于同一颗子树
{
if(dis[sub_node[r]] == dis[sub_node[r-1]]) r--;
else l++;
}
else
{
ok[i] = true;
break;
}
}
}
//分治
for(int i=head[pos];i;i=edge[i].ne)
{
if(w[edge[i].v]) continue;
work(edge[i].v,sz[edge[i].v]);
}
}
int main()
{
//freopen("test.in","r",stdin);
read();
work(1,n);
for(int i=1;i<=m;i++)
{
printf("%s\n",ok[i]?"AYE":"NAY");
}
return 0;
}
CF161D 计算有多少个点对距离是k
k值很小,采用桶记录
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn = 50005;
struct E
{
int v,w,ne;
}edge[maxn<<1];
int n,k,tot = 1,head[maxn];
void add(int u,int v)
{
//edge[++tot] = (E){v,1,head[u]};
edge[++tot].v =v,edge[tot].w = 1,edge[tot].ne=head[u];
head[u] = tot;
}
void read()
{ int x,y;
scanf("%d%d",&n,&k);
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
}
int sz[maxn],ans,pos;
bool v[maxn],w[maxn];
void find_root(int x,int S)
{
int y,max_part = 0;
v[x] = 1;
sz[x] = 1;
for(int i=head[x];i;i=edge[i].ne)
{
y = edge[i].v;
if(v[y] || w[y]) continue;
find_root(y,S);
sz[x]+=sz[y];
max_part = max(max_part,sz[y]);
}
max_part = max(max_part,S-sz[x]);
if(max_part < ans)
{
ans = max_part;
pos = x;
}
}
int dis[maxn];
ll ANS;
ll bucket[505];
void dfs(int x)
{ int y;
if(dis[x] > k) return;
bucket[dis[x]]++;//桶排序
v[x] = 1;
//printf("(%d)",x);
for(int i=head[x];i;i=edge[i].ne)
{
y = edge[i].v;
if(v[y] || w[y]) continue;
dis[y] = dis[x] + 1;
dfs(y);
}
}
ll calc()
{ /*for(int i=0;i<=k;i++)
{
printf("%lld ",bucket[i]);
}
cout << endl;*/
ll ret=0;
for(int i = 0;i <= k;i++)
{
if(i >= k-i) break;
ret+=bucket[i]*bucket[k-i];
}
if(k%2==0)//偶数
{
ret+=(bucket[k/2]*(bucket[k/2]-1))/2;
}
return ret;
}
void work(int x,int S)
{ ans = S;
memset(v,0,sizeof(v));
find_root(x,S);
memset(v,0,sizeof(v));
memset(dis,0,sizeof(dis));
memset(bucket,0,sizeof(bucket));
w[pos] = true;
dfs(pos);
//ANS+=calc();
//cout << pos << ":";
ll tmp = calc();
ANS+=tmp;
memset(dis,0,sizeof(dis));
memset(v,0,sizeof(v));
for(int i=head[pos];i;i=edge[i].ne)
{
int y = edge[i].v;
if(v[y] || w[y]) continue;
dis[y] = 1;
memset(bucket,0,sizeof(bucket));
dfs(y);
//ANS-=calc();//减去子树内的重复计算
// cout << y << ":";
tmp = calc();
ANS-=tmp;
}
for(int i=head[pos];i;i=edge[i].ne)
{
int y = edge[i].v;
if(w[y]) continue;
work(y,sz[y]);
}
}
int main()
{
read();
work(1,n);
printf("%I64d",ANS);
return 0;
}
P2634 【国家集训队】聪明可可
除了用容斥原理两次dfs减去重复计算或者是标记根节点直接减去的方法,也可以考虑根据递归的顺序,记录答案= 当前子树x已经遍历的子树,从而避免重复计算第二种路径。
符合要求的路径一定是端点在以根节点儿子为根的不同子树里,遍历时按照子树顺序依次遍历。
用disx[]数组记录下之前已经遍历的子树结果,disy[]记录当前遍历的子树,计算答案ANS+=disx[]*disy[]
最后把当前子树的遍历结果合并入已经遍历的子树结果数组disx[].
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e4+5;
struct E
{
int v,w,ne;
}edge[maxn<<1];
int tot = 1, head[maxn];
int n;
int gcd(int x,int y)
{
return y==0?x : gcd(y,x%y);
}
void add(int u,int v,int w)
{
edge[++tot] = (E){v,w,head[u]};
head[u] = tot;
}
void read()
{ int u,v,w;
scanf("%d",&n);
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
}
int sz[maxn],pos,ans;
bool vis[maxn],w[maxn];
void find_root(int x,int S)
{
int y,max_part = 0;
sz[x] = vis[x] = 1;
for(int i=head[x];i;i=edge[i].ne)
{
y = edge[i].v;
if(vis[y] || w[y]) continue;
find_root(y,S);
sz[x]+=sz[y];
max_part = max(max_part,sz[y]);
}
max_part = max(max_part,S-sz[x]);
if(max_part < ans)
{
ans = max_part;
pos = x;
}
}
int ANSX=0,ANSY=0,disx[3],disy[3];
void dfs(int x,int dist)
{ int y;
vis[x] = 1;
disy[dist]++;
for(int i=head[x];i;i=edge[i].ne)
{
y = edge[i].v;
if(vis[y] || w[y]) continue;
dfs(y,(dist+edge[i].w)%3);
}
}
void work(int x,int S)
{
ans = S;
memset(vis,0,sizeof(vis));
find_root(x,S);
w[pos] = 1;
memset(vis,0,sizeof(vis));
disx[0] = disx[1] = disx[2] = 0;
disy[0] = disy[1] = disy[2] = 0;
int y;
for(int i=head[pos];i;i=edge[i].ne)
{ y = edge[i].v;
if(vis[y] || w[y]) continue;
dfs(y,edge[i].w%3);
ANSX += disx[0]*disy[0]*2 + disx[1]*disy[2]*2 + disx[2]*disy[1]*2+disy[0]*2;//disy[0]:根节点出发长度为3
for(int j=0;j<3;j++)
disx[j]+=disy[j],disy[j] = 0;
}
for(int i = head[pos];i;i=edge[i].ne)
{ y = edge[i].v;
if(w[y]) continue;
work(y,sz[edge[i].v]);
}
}
int main()
{ freopen("test.in","r",stdin);
read();
work(1,n);
ANSX += n;//两点重合
ANSY = n*n;
int tmp = gcd(ANSX,ANSY);
printf("%d/%d",ANSX/tmp, ANSY/tmp);
return 0;
}