这种题还考的挺多的(
口胡),我稍微总结了一下,不多,就两个题目其中一个就是 luogu U41492 树上数颜色
另外一个是 东北四省E题 算是这题的加强版
此处提供两种算法:1.启发式合并 2.主席树
关于启发式合并,大家可以看看这片博客 https://www.luogu.org/blog/codesonic/dsu-on-tree
代码略微相似:
#include<bits/stdc++.h>
using namespace std;
//实测比主席树+dfn序还是要慢不少的
//方法偏向暴力
const int N = 1e5+10;
int n;
int cur,nex[N*2],to[N*2],h[N];
int cnt[N],size[N],son[N],c[N],ans[N];
void addedge(int u,int v){
nex[++cur] = h[u];to[cur] = v;h[u] = cur;
}
void dfs1(int u,int fa){
size[u] = 1;
for(int j = h[u]; j; j = nex[j]){
if(to[j] == fa) continue;
dfs1(to[j],u);
size[u]+=size[to[j]];
if(!son[u]||size[son[u]]<size[to[j]]) son[u] = to[j];
}
//printf("u = %d son[u] = %d\n",u,son[u]);
}
int dfs2(int u,int fa,int isson,int keep){
if(keep){//先遍历非重儿子,并保留答案
for(int j = h[u]; j; j = nex[j]){
if(to[j] == fa||to[j] == son[u]) continue;
dfs2(to[j],u,0,1);
}
}
int temp = 0;
if(!keep&&son[u]) temp+=dfs2(son[u],u,1,0);
else if(son[u]) temp+=dfs2(son[u],u,1,1);//再遍历重儿子 根据keep看是否保留答案
for(int j = h[u]; j; j = nex[j]){//暴力获得当前结点非重儿子的cnt
if(to[j] == fa||to[j] == son[u]) continue;
temp+=dfs2(to[j],u,0,0);
}
if(!cnt[c[u]]) temp++;
cnt[c[u]]++;
if(keep) ans[u] = temp;
if(keep && !isson) memset(cnt,0,sizeof(cnt));
return temp;
}
int main(){
scanf("%d",&n);
for(int i = 1; i <= n-1; i++){
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
for(int i = 1; i <= n; i++) scanf("%d",&c[i]);
dfs1(1,0);
dfs2(1,0,1,1);
int m,k;
scanf("%d",&m);
while(m--){
scanf("%d",&k);
printf("%d\n",ans[k]);
}
return 0;
}
主席树就没啥入门的博客推荐了,这是我自己写的:
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+10;
int c[N],mp[N],ans[N];
int cnt,t[N*40],R[N*40],L[N*40],root[N];
int cur,l[N],r[N],dfn[N];
int tot,h[N],nex[N*2],to[N*2];
int n,m;
void addedge(int u,int v){
nex[++tot] = h[u];to[tot] = v;h[u] = tot;
}
void dfs(int u,int fa){
l[u] = ++cur;dfn[cur] = u;
for(int i = h[u]; i; i = nex[i]){
if(to[i] == fa) continue;
dfs(to[i],u);
}r[u] = cur;
}
void update(int &o,int last,int l,int r,int pos,int v){
o = ++cnt;
t[o] = t[last]+v;R[o] = R[last];L[o] = L[last];
int mid = l+r>>1;
if(l == r) return;
if(pos<=mid) update(L[o],L[last],l,mid,pos,v);
else update(R[o],R[last],mid+1,r,pos,v);
}
int query(int now,int l,int r,int k){
if(k<=l) return t[now];
if(k>r) return 0;
int mid = l+r>>1;
if(k<=mid) return query(L[now],l,mid,k)+t[R[now]];
else return query(R[now],mid+1,r,k);
}
int main(){
scanf("%d",&n);
for(int i = 1; i <= n-1; i++){
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
for(int i = 1; i <= n; i++){
scanf("%d",&c[i]);
}
dfs(1,0);
for(int i = 1; i <= n; i++){
int u = dfn[i];
if(!mp[c[u]]){
update(root[i],root[i-1],1,n,i,1);
}else{
update(root[i],root[i-1],1,n,mp[c[u]],-1);
update(root[i],root[i],1,n,i,1);
}
mp[c[u]] = i;
}
for(int i = 1; i <= n; i++){
ans[i] = query(root[r[i]],1,n,l[i]);
}
scanf("%d",&m);
while(m--){
int u;
scanf("%d",&u);
printf("%d\n",ans[u]);
}
return 0;
}
在洛谷上面,主席树跑的更快,那篇博客也说了,启发式合并比较偏暴力,可能没那么快
另外hdu上面的那题比较有趣,你要把一颗树拆成两半,并且统计每一半的颜色个数
这是启发式合并的写法,借鉴博客:https://blog.csdn.net/qq_30358129/article/details/89889688
这个乍一看很难理解,不过你要是动手模拟一下这个过程,会发现特别神奇,照样遵循了重儿子单次遍历,非重儿子二次遍历的原则,而且跑的比主席树还快不少。。。
//需要用到第x个颜色的总数,和第x个颜色在第i棵子树内的个数来判断删除第i棵子树后的答案变化
#include<bits/stdc++.h>
#define rep(i, a, b) for(int i = (a); i <= (b); i++)
#define pb push_back
using namespace std;
typedef long long ll;
const int N=1e5+100;
vector<int> nxt[N];
int n,ans,val[N],cnt[N],tot,dif[N]; //cnt[i]记录i的总共出现次数 ,dif[i]删除第i棵子树的答案,val[i]记录第i个点的权值
unordered_map<int,int> re[N]; //re[i][x]记录第i棵子树中x的出现次数
void init(int n) {
memset(cnt,0,sizeof(cnt));
rep(i, 1, n) {
nxt[i].clear();
re[i].clear();
dif[i] = 0;
}
ans = 0;
tot = 0;
}
int siz[N],son[N];
void dfs(int u,int f) {
son[u] = 0; //重儿子
siz[u] = 1;
for(auto v:nxt[u]) {
if(v==f)continue;
dfs(v,u);
siz[u] += siz[v];
if(siz[son[u]]<siz[v]) son[u] = v; //维护重儿子
}
}
void dfs2(int u,int f) {
for(auto v:nxt[u]) {
if(v==f) continue;
dfs2(v,u);
}
if(siz[u]==1) {
re[u][val[u]] = 1;
dif[u] = tot+1;
if(re[u][val[u]]==cnt[val[u]]) dif[u]--;
ans = max(ans,dif[u]);
}
else {
swap(re[u],re[son[u]]);
dif[u] = dif[son[u]];
for(auto v:nxt[u]) {
if(v==f||v==son[u]) continue;
for(auto x:re[v]) {
if(re[u][x.first]==0) dif[u]++;
re[u][x.first]+=x.second;
if(re[u][x.first]==cnt[x.first]) dif[u]--;
}
}
if(re[u][val[u]]==0) dif[u]++;
re[u][val[u]]++;
if(re[u][val[u]]==cnt[val[u]]) dif[u]--;
ans = max(ans,dif[u]);
}
}
int main() {
// freopen("a.txt","r",stdin);
ios::sync_with_stdio(0);
while(cin>>n) {
init(n);
rep(i, 2, n) {
int f;
cin>>f;
nxt[i].pb(f);
nxt[f].pb(i);
}
rep(i, 1, n) {
cin>>val[i];
if(cnt[val[i]]==0) tot++;
cnt[val[i]]++;
}
dfs(1,0);
dfs2(1,0);
cout<<ans<<endl;
}
}
另外,我结合我学长写的代码,写了一份比较搓的主席树题解
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N = 2e5+10;
int h[N],v[N<<1],nex[N<<1],cur;
int df[N<<1],cnt,ri[N],li[N];
int tot,L[N*40],R[N*40],t[N*40],root[N<<1];
int n,k;
int val[N],mp[N];
void build(int &o,int l,int r){
o = ++tot;
t[o] = 0;
if(l == r) return;
int mid = l+r>>1;
build(L[o],l,mid);build(R[o],mid+1,r);
}
void init(){
for(int i = 1; i <= n; i++) h[i] = 0;
cur = tot = cnt = 0;
build(root[0],1,2*n);
}
void add_edge(int x,int y){
nex[++cur] = h[x];h[x] = cur;v[cur] = y;
}
void dfs(int u,int f){
df[++cnt] = u; li[u] = cnt;
for(int j = h[u]; j; j=nex[j]){
if(v[j] == f) continue;
dfs(v[j],u);
}ri[u] = cnt;
}
void update(int &o,int last,int l,int r,int pos,int v){
o = ++tot;
t[o] = t[last]+v;
L[o] = L[last],R[o] = R[last];
if(l == r) return;
int mid = l+r>>1;
if(pos<=mid) update(L[o],L[last],l,mid,pos,v);
else update(R[o],R[last],mid+1,r,pos,v);
}
int query(int o,int l,int r,int pos){
if(l==r) return t[o];
int mid = l+r>>1;
if(pos<=mid) return query(L[o],l,mid,pos)+t[R[o]];
else return query(R[o],mid+1,r,pos);
}
int main(){
while(~scanf("%d",&n)){
init();
for(int i = 2; i <= n; i++){
scanf("%d",&k);
add_edge(i,k);
add_edge(k,i);
}
dfs(1,0);
for(int i = 1; i <= n; i++){
scanf("%d",&val[i]);
mp[val[i]] = 0;
}
for(int i = n+1; i <= 2*n; i++) df[i] = df[i-n];
for(int i = 1; i <= 2*n; i++){
int u = df[i];
if(mp[val[u]]){
update(root[i],root[i-1],1,2*n,mp[val[u]],-1);
update(root[i],root[i],1,2*n,i,1);
}else update(root[i],root[i-1],1,2*n,i,1);
mp[val[u]] = i;
}
int ans = 0;
for(int i = 1; i <= n; i++){
int temp = query(root[ri[i]],1,2*n,li[i]);
temp+=query(root[li[i]+n-1],1,2*n,ri[i]+1);
ans = max(ans,temp);
}
printf("%d\n",ans);
}
return 0;
}