所以树上分治是什么呢?树上分治是解决一类树上路径计数的问题。
举个例子,我现在要知道树上有多少条路的距离小于等于K。
如果直接取点判断肯定是不行的,我们考虑一条路径,肯定有两种可能,一个是经过根节点,或者不在根节点。那我们就可以先找一个根节点,将一颗树转化成一个有根树。然后我们对以这点为中心,计算有多少个符合的路径。当然这一次计算会发现把子树上的也算进去了。删掉这个点,然后我们再遍历所有的子树,先减去子树的答案(容斥)然后再找一个根节点递归下去。
所以,树分治整个过程可以分为一下几个部分:
1.怎么统计答案?
2.再树上找到一个合适点作为根节点
3.递归的取对每个子树计数。
第一个部分当然要具体题目具体分析。所以如何找到一个合适的点?我们想让这个点子树的大小最小,这里我们引入树的重心:这个点所有子树的最大值最小。这个重心有一个性质:树中所有点到某个点的距离和中,到重心的距离和是最小的,如果有两个距离和,他们的距离和一样。我们利用这个点就可以很方便的进行第一步的计数了。
接下来的递归就很简单了,我们来看看具体的例题吧:
POJ - 1741
这个就是上面说的题,我们找到重心之后,先更新所有点到中心的距离,对所有的点按照离重心的距离排序,用双指针判断,如果一个区间是符合的,那么按照排序之后的结果,中间的也是符合要求的
#include <iostream>
#include <cstring>
#include <vector>
#include <cstdio>
#include <algorithm>
#define ll long long
#define next fuck
#define mp make_pair
#define pb push_back
#define INF 0x3f3f3f3f
#define mm(a,b) memset(a,b,sizeof(a))
using namespace std;
const int maxn=1e4+50;
struct Node{
int v,w,next;
}edge[2*maxn];
int head[maxn],tol;//节点从1开始计数
int son[maxn],f[maxn],vis[maxn];
int dep[maxn],siz,d[maxn];
int cntv,root,K,ans;
void init(){
mm(head,-1);tol=0;
mm(vis,0);
f[0]=INF;//一定要定义成最大的值
}
void addedge(int u,int v,int w){
edge[tol].v=v;
edge[tol].w=w;
edge[tol].next=head[u];
head[u]=tol++;
}
void getroot(int u,int fa){//寻找重心
son[u]=1,f[u]=0;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa||vis[v]) continue;
getroot(v,u);
son[u]+=son[v];
f[u]=max(f[u],son[v]);
}
f[u]=max(f[u],cntv-son[u]);//sum表示当前树的大小
if(f[u]<f[root]) root=u;//更新当前重心
}
void getdepth(int u,int fa){
dep[siz++]=d[u];
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa||vis[v]) continue;
d[v]=d[u]+edge[i].w;
getdepth(v,u);
}
}
int cal(int u,int w){
d[u]=w;
siz=0;
getdepth(u,0);
sort(dep,dep+siz);
int l=0,r=siz-1,res=0;
while(l<r){
if(dep[l]+dep[r]<=K){
res+=r-l;
l++;
}
else r--;
}
return res;
}
void solve(int u){//计算以u为重心的树
ans+=cal(u,0);
vis[u]=1;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(vis[v]) continue;
ans-=cal(v,edge[i].w);//减去不合法的数目
cntv=son[v];
root=0;
getroot(v,0);
solve(root);
}
}
int main()
{
int n;
while(scanf("%d%d",&n,&K) != EOF)
{
if(n == 0 && K == 0) break;
init();
for(int i = 0;i<n-1;i++)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
addedge(u,v,w);
addedge(v,u,w);
}
cntv = n;
root = 0;
getroot(1
,0);
ans = 0;
solve(root);
printf("%d\n",ans);
}
return 0;
}
HDU - 4812 D tree
这道题稍难一点(树分没有简单题),求有多少条路径上点权乘积对mod取模,结果是k。
首先我们先思考怎么找乘积,我们还是先更新每个颠倒重心的权值乘积取模,对于一个乘积,我们可以选择map记录他对K的逆元,每次查询一次之后清空整个map
代码写的太丑就不放了
bzoj 2152
判断路径边权和为3的倍数的有多少,答案和n2取个最简分数。
这个模数比较小,我们保存每次对中心取距离之后,余数的情况。利用余数的情况就能排列组合一搞,答案就出来了。
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <map>
#include <iostream>
#include <algorithm>
#include <cmath>
#define next fuck
using namespace std;
const int MAXN = 2e5 + 5;
int n,cnt,ans,root,sum;
int head[MAXN],tot;
int son[MAXN],f[MAXN],d[MAXN],t[5];
bool vis[MAXN];
inline int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
struct Edge
{
int to,next,w;
} edge[MAXN*2];
void addedge(int u,int v,int w)
{
edge[tot].to =v;
edge[tot].w = w;
edge[tot].next = head[u];
head[u] = tot++;
}
void getroot(int u,int p)
{
son[u] = 1,f[u] = 0;
for(int i = head[u]; i != -1; i = edge[i].next)
{
int v = edge[i].to;
if(!vis[v] && v != p)
{
getroot(v,u);
son[u] += son[v];
f[u] = max(f[u],son[v]);
}
}
f[u] = max(f[u],sum - son[u]);
if(f[u] < f[root])
root = u;
}
void getdeep(int u,int p)
{
t[d[u]]++;
for(int i = head[u]; i != -1; i = edge[i].next)
{
int v = edge[i].to;
if(!vis[v] && v != p)
{
d[v] = (d[u] + edge[i].w) % 3;
getdeep(v,u);
}
}
}
int calc(int x,int now)
{
t[0] = t[1] = t[2] = 0;
d[x] = now;
getdeep(x,0);
return t[1]*t[2]*2+t[0]*t[0];
}
void solve(int x)
{
ans += calc(x,0);
vis[x] = 1;
for(int i = head[x]; i != -1; i = edge[i].next)
{
int v= edge[i].to;
if(!vis[v])
{
ans -= calc(v,edge[i].w);
root = 0;
sum = son[v];
getroot(v,0);
solve(root);
}
}
}
int main()
{
n = read();
tot = 0;
memset(head,-1,sizeof(head));
for(int i = 1; i<n; i++)
{
int u = read(),v = read(),w = read()%3;
addedge(u,v,w);
addedge(v,u,w);
}
f[0] = sum = n;
getroot(1,0);
solve(root);
int t = __gcd(ans,n*n);
printf("%d/%d\n",ans/t,n*n/t);
return 0;
}