树上启发式合并 学习笔记
树上启发式合并是一种最近几年才出现的黑科技,解决的是这么一类问题:统计树上以每一个节点为根节点的子树的信息,单组询问的话这个问题就没啥意思了,直接暴力 D F S DFS DFS统计就好了,所以一般是多组询问,例如统计子树的颜色种类,或者统计子树的某个指定颜色的节点个数,下面以第二个问题为例阐述树上启发式合并算法,如果文章有任何错误或者不严谨的地方,欢迎各位批评指正
-
先看看其他解法
-
假做法:暴力
- 暴力统计每一颗子树的贡献,即使是暴力,你可能也会发现你的做法会爆空间 233 233 233,原因是你可能开个 d p [ m a x v e r t e x ] [ m a x c o l o r ] dp[max_{vertex}][max_{color}] dp[maxvertex][maxcolor]的数组做树形 d p dp dp,当节点数,颜色种类数量都达到 1 0 5 10^5 105就不行了,考虑另外一种方式:只开一维数组 d p [ m a x c o l o r ] dp[max_{color}] dp[maxcolor],每次暴力 D F S DFS DFS整颗以当前节点为根节点的子树,这时就有了以当前节点为根节点的信息,然后清空数组,继续 D F S DFS DFS(这个思想在树上启发式合并中有用到),分析一下复杂度:由于每个节点平均被统计 n n n次,所以复杂度 O ( n 2 ) O(n^2) O(n2),「顺便剧透一下:树上启发式合并能做到每个节点最多被统计 log n \log n logn次,所以其复杂度为 O ( n log n ) O(n\log n) O(nlogn) 」,这里直接粘贴的 c o d e f o r c e s codeforces codeforces上博客的代码:
-
int cnt[maxn]; void add(int v, int p, int x){ cnt[ col[v] ] += x; for(auto u: g[v]) if(u != p) add(u, v, x) } void dfs(int v, int p){ add(v, p, 1); //now cnt[c] is the number of vertices in subtree of vertex v that has color c. //You can answer the queries easily. add(v, p, -1); for(auto u : g[v]) if(u != p) dfs(u, v); }
-
D
F
S
DFS
DFS序加莫队
- 这个做法就是建 D F S DFS DFS序,然后在DFS序上做莫队,复杂度 O ( n n ) O(n \sqrt n) O(nn)数据量达到 1 0 6 10^6 106可能就会被卡,莫队的这种做法具体就不说了,可以看一看网上的博客
-
-
「树上启发式合并」算法步骤
- 先遍历轻儿子,暴力统计每个以轻儿子为根的子树中所有节点的信息,统计后在数组中不保留信息
- 统计以重儿子为根的子树的所有节点的信息,保留该信息
- 统计以当前节点为根的子树的中的所有节点信息,除去以重儿子为根的子树中的所有节点的信息(因为第二步统计过了)
- 放上一张图:黑色加粗边为重链,其他的为轻链,以当前统计的节点为节点1为例:
-
「树上启发式合并」算法伪码
- 这里也是直接粘贴的codeforces上的代码,如果感觉不太好理解,建议先跳到第一个例题看一下那个代码,有很详细的说明介绍
-
int cnt[maxn]; bool big[maxn]; void add(int v, int p, int x){ cnt[ col[v] ] += x; for(auto u: g[v]) if(u != p && !big[u]) add(u, v, x) } void dfs(int v, int p, bool keep){ int mx = -1, bigChild = -1; for(auto u : g[v]) if(u != p && sz[u] > mx) mx = sz[u], bigChild = u; for(auto u : g[v]) if(u != p && u != bigChild) dfs(u, v, 0); // run a dfs on small childs and clear them from cnt if(bigChild != -1) dfs(bigChild, v, 1), big[bigChild] = 1; // bigChild marked as big and not cleared from cnt add(v, p, 1); //now cnt[c] is the number of vertices in subtree of vertex v that has color c. You can answer the queries easily. if(bigChild != -1) big[bigChild] = 0; if(keep == 0) add(v, p, -1); }
-
「树上启发式合并」算法思想
- 中心思想来源于树链剖分,如果还没学过的话建议先学一下树链剖分
- 我们知道,做树剖的时候,将所有的儿子划分成重儿子和轻儿子,重儿子就是所有的儿子当中其子树大小(节点数量)最大的那个,其他的儿子都为轻儿子,所以每次对于轻儿子
s
o
n
son
son必有:
s
i
z
e
[
s
o
n
]
<
=
s
i
z
e
[
f
a
t
h
e
r
]
2
size[son]<=\frac{size[father]}{2}
size[son]<=2size[father]
这个可以简单的用反证法证明一下,那么就是说 从任意一个节点出发到根节点的简单路径上的轻边数量最多为 l o g n log\ n log n,而我们发现每次树上启发式合并在统计的时候走的是轻边,这样就说明每个节点最多被统计 l o g n log\ n log n次( l o g n log\ n log n条轻边),故复杂度也就是 n l o g n nlog\ n nlog n
-
参考
-
例题
hdu4358
-
题意:
- 统计每一个节点的子树中有多少种颜色
-
解法一 ⇒ \Rightarrow ⇒ dfs序+莫队 ( O ( n n ) ) (O(n\sqrt{n})) (O(nn))
#include<iostream> #include<cstdio> #include<algorithm> #include<vector> #include<cmath> #include<map> using namespace std; const int maxn=100005; //题目数据 vector<int> vec[maxn]; int t,n,k,q,a[maxn],b[maxn],tot,u,v,cnt[maxn],sum,ans[maxn]; //dfs序 int tim,in[maxn],out[maxn],fin[maxn],val[maxn]; template<typename T> inline void read(T &x) { x = 0;int f = 1;char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }x=x*f; } void init() { tim=0; for(int i=1;i<=n;i++) vec[i].clear(); } void dfs(int cur,int fa) { in[cur]=++tim;fin[tim]=cur;val[tim]=a[cur]; for(int i=0;i<vec[cur].size();i++){ if(vec[cur][i]!=fa){ dfs(vec[cur][i],cur); } } out[cur]=tim; } struct node{ int l,r,id,block; node(int a=0,int b=0,int c=0){ l=a;r=b;id=c;block=c/sqrt(n); } friend bool operator<(const node &a,const node &b){ return a.block==b.block?a.r<b.r:a.l<b.l; } }query[maxn]; void del(int pos) { if(cnt[val[pos]]==k) sum--; cnt[val[pos]]--; if(cnt[val[pos]]==k) sum++; } void add(int pos) { if(cnt[val[pos]]==k) sum--; cnt[val[pos]]++; if(cnt[val[pos]]==k) sum++; } void modui() { sum=0;int l=1,r=0; for(int i=1;i<=tot;i++) cnt[i]=0; for(int i=1;i<=q;i++){ while(l<query[i].l) del(l++); while(l>query[i].l) add(--l); while(r<query[i].r) add(++r); while(r>query[i].r) del(r--); ans[query[i].id]=sum; } } int main() { read(t); for(int cas=1;cas<=t;cas++){ read(n);read(k); init(); for(int i=1;i<=n;i++) read(a[i]),b[i]=a[i]; sort(b+1,b+n+1); tot=unique(b+1,b+n+1)-b-1; for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+tot+1,a[i])-b; for(int i=1;i<n;i++){ read(u);read(v); vec[u].push_back(v); vec[v].push_back(u); } dfs(1,0); read(q); for(int i=1;i<=q;i++) { read(u); query[i]=node(in[u],out[u],i); } sort(query+1,query+q+1); modui(); printf("Case #%d:\n",cas); for(int i=1;i<=q;i++) printf("%d\n",ans[i]); if(cas<t) printf("\n"); } }
-
解法二 ⇒ \Rightarrow ⇒ 树上启发式合并 ( O ( n log n ) ) (O(n \log n)) (O(nlogn))
#include<cstdio> #include<algorithm> #include<cstring> #include<vector> using namespace std; const int maxn=100005; //题目数据 int t,n,u,v,q,k,a[maxn],b[maxn],ans[maxn],cnt[maxn],sum,m; vector<int> vec[maxn]; //树剖用 int tot=0,siz[maxn],son[maxn]; int vis[maxn]; void dfs1(int cur,int fath,int he){ //dfs(root,0,1) siz[cur]=1; for(int i=0;i<vec[cur].size();i++){ if(vec[cur][i]!=fath){ dfs1(vec[cur][i],cur,he+1); siz[cur]+=siz[vec[cur][i]]; if(siz[vec[cur][i]]>siz[son[cur]]) son[cur]=vec[cur][i]; } } } void calc(int cur,int fa,int val) { if(cnt[a[cur]]==k) sum--; cnt[a[cur]]+=val; if(cnt[a[cur]]==k) sum++; for(int i=0;i<vec[cur].size();i++){ if(vec[cur][i]!=fa&&!vis[vec[cur][i]]){ //!vis[vec[cur][i]]表示不计算重儿子的贡献,因为已经计算过 calc(vec[cur][i],cur,val); } } } void dfs(int cur,int fa,bool keep) //keep表示以当前节点为根节点的子树的贡献是否保留 { for(int i=0;i<vec[cur].size();i++){ if(vec[cur][i]!=fa&&vec[cur][i]!=son[cur]){ dfs(vec[cur][i],cur,0);//计算轻链的结果,并且不保存 } if(son[cur]) dfs(son[cur],cur,1),vis[son[cur]]=1; //计算重儿子贡献,并打上标记,防止下一次calc的时候重复计算 calc(cur,fa,1);//计算当前节点和所有以轻儿子为根节点的子树的贡献 ans[cur]=sum; //这里就计算好了整颗以当前节点为根节点的子树的贡献 if(son[cur]) vis[son[cur]]=0; //消除标记,防止对下一步的清空操作产生影响,因为如果当前节点是轻儿子的话,那么整颗以 //当前节点为根节点的子树的贡献都要清空,如果不把重儿子的标记去掉,那么以重儿子为根的 //子树的贡献就无法在下一步清空 if(!keep) calc(cur,fa,-1); //如果当前节点作为父节点的轻儿子,那么消除以当前节点为根节点的子树的影响 } void init() { tot=0;sum=0; memset(son,0,sizeof(son)); memset(cnt,0,sizeof(cnt)); memset(vis,0,sizeof(vis)); for(int i=1;i<=n;i++) vec[i].clear(); } int main() { scanf("%d",&t); for(int cas=1;cas<=t;cas++){ scanf("%d %d",&n,&k);init(); for(int i=1;i<=n;i++) scanf("%d",&a[i]),b[i]=a[i]; sort(b+1,b+n+1); m=unique(b+1,b+n+1)-b-1; for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+m+1,a[i])-b; for(int i=1;i<n;i++){ scanf("%d %d",&u,&v); vec[u].push_back(v); vec[v].push_back(u); } dfs1(1,0,1); dfs(1,0,0); scanf("%d",&q); printf("Case #%d:\n",cas); for(int i=1;i<=q;i++){ scanf("%d",&u); printf("%d\n",ans[u]); } if(cas<t) printf("\n"); } }
CF600E- Lomsat gelral
You are given a rooted tree with root in vertex 1 1 1. Each vertex is coloured in some colour.
Let’s call colour c c c dominating in the subtree of vertex v v v if there are no other colours that appear in the subtree of vertex v v v more times than colour c c c. So it’s possible that two or more colours will be dominating in the subtree of some vertex.
The subtree of vertex v v v is the vertex v v v and all other vertices that contains vertex v v v in each path to the root.
For each vertex v v v find the sum of all dominating colours in the subtree of vertex v v v.
Input
The first line contains integer n ( 1 ≤ n ≤ 1 0 5 ) n (1 ≤ n ≤ 10^5) n(1 ≤ n ≤ 105) — the number of vertices in the tree.
The second line contains n integers c i ( 1 ≤ c i ≤ n ) , c i c_i (1 ≤ c_i ≤ n), c_i ci(1 ≤ ci ≤ n),ci — the colour of the i i i-th vertex.
Each of the next n − 1 n - 1 n − 1 lines contains two integers x j , y j ( 1 ≤ x j , y j ≤ n ) x_j, y_j (1 ≤ x_j, y_j ≤ n) xj, yj(1 ≤ xj, yj ≤ n) — the edge of the tree. The first vertex is the root of the tree.
Output
Print n n n integers — the sums of dominating colours for each vertex.
Examples
input
4
1 2 3 4
1 2
2 3
2 4
output
10 9 3 4
input
15
1 2 3 1 2 3 3 1 1 3 2 2 1 2 3
1 2
1 3
1 4
1 14
1 15
2 5
2 6
2 7
3 8
3 9
3 10
4 11
4 12
4 13
output
6 5 4 3 2 3 3 1 1 3 2 2 1 2 3
-
题意
- 统计以每一个节点的子树中颜色数量最多的颜色编号之和(每种颜色只算一次)
-
解法
- 当然莫队应该是可以的 ,这是树上启发式合并的入门题
-
代码
#include<bits/stdc++.h> using namespace std; const int maxn=1e5+10; //题目数据 int col[maxn],n,u,v,vis[maxn],cnt[maxn],maxx=0;//sum表示颜色个数为i的节点颜色之和,cnt表示颜色为i的节点数量 vector<int> vec[maxn]; long long ans[maxn],sum[maxn]; //树剖用 int tot=0,siz[maxn],son[maxn]; void dfs1(int cur,int fa) { siz[cur]=1; for(int i=0;i<vec[cur].size();i++){ if(vec[cur][i]!=fa){ dfs1(vec[cur][i],cur); siz[cur]+=siz[vec[cur][i]]; if(siz[vec[cur][i]]>siz[son[cur]]) son[cur]=vec[cur][i]; } } } void calc(int cur,int fa,int val) { sum[cnt[col[cur]]]-=col[cur]; cnt[col[cur]]+=val; sum[cnt[col[cur]]]+=col[cur]; if(cnt[col[cur]]>maxx){ maxx=cnt[col[cur]]; } if(sum[maxx]==0) maxx=cnt[col[cur]]; for(int i=0;i<vec[cur].size();i++){ if(vec[cur][i]!=fa&&!vis[vec[cur][i]]){ calc(vec[cur][i],cur,val); } } } void dfs(int cur,int fa,int keep) { for(int i=0;i<vec[cur].size();i++){ if(vec[cur][i]!=fa&&vec[cur][i]!=son[cur]){ dfs(vec[cur][i],cur,0); } } if(son[cur]) dfs(son[cur],cur,1),vis[son[cur]]=1; calc(cur,fa,1); ans[cur]=sum[maxx]; if(son[cur]) vis[son[cur]]=0; if(!keep) calc(cur,fa,-1); } int main() { scanf("%d",&n); for(int i=1;i<=n;i++) scanf("%d",&col[i]); for(int i=1;i<n;i++){ scanf("%d %d",&u,&v); vec[u].push_back(v); vec[v].push_back(u); } dfs1(1,0); dfs(1,0,1); for(int i=1;i<=n;i++) printf("%lld%c",ans[i],i==n?'\n':' '); }