题意:
给你一颗树,每一个节点有一个颜色,在节点v的子树中,颜色x出现的次数最多,则称x支配v的子树,注意一颗子树可能被多个颜色支配,让你输出对于每一个节点,支配他的子树的颜色的编号和
解析:
树上启发式合并的模板题
总结一下,对于在做节点x的子树问题时,答案是通过x的重孩子继承过来,然后在遍历x的每一个轻孩子,
来更新答案,使得其变成x的子树问题的答案
状态量取决于他的孩子节点时,如果状态量是一个变量,那么直接从他的重孩子继承
,再遍历轻孩子时更新。对于一维的状态数组,则用STL的二维指针来做,状态数组先指向重孩子的状态数组,
之后在遍历轻孩子时考虑重新申请空间还是在原有重孩子的状态数组上更新
对于这道题,我做烦了...其实不需要set<int> *vis[maxn]这个状态数组的。
因为如果在更新轻孩子时(此时重孩子以遍历过),出现更大的颜色,那么直接重新开始计数
如果没有出现更大的颜色(可能出现与最大值相等的颜色),那么直接将这个颜色加到ans中就好了,
因为这个颜色在重孩子中一定是<最大值的,所以没有被加入到答案过
我的代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll MOD = 1e9+7;
const int maxn = 1e5+10;
int res[maxn];
ll sum[maxn];
set<int> *vis[maxn];
int col[maxn];
vector<int> g[maxn];
int sz[maxn],st[maxn];
int ver[maxn];
void getsz(int v, int p,int& tim){
sz[v] = 1; // every vertex has itself in its subtree
ver[tim] = v;
st[v]=tim;
for(auto u : g[v])
if(u != p){
++tim;
getsz(u, v, tim);
sz[v] += sz[u]; // add size of child u to its parent(v)
}
}
int cnt[maxn];
void dfs(int v, int p, bool keep){
int mx = -1, bigChild = -1;
for(auto u : g[v])
if(u != p && sz[u] > mx)
mx = sz[u], bigChild = u;
for(auto u : g[v])
if(u != p && u != bigChild)
dfs(u, v, 0); // run a dfs on small childs and clear them from cnt
if(bigChild != -1)
dfs(bigChild, v, 1),res[v]=res[bigChild],sum[v]=sum[bigChild],vis[v]=vis[bigChild]; // bigChild marked as big and not cleared from cnt
else
res[v]=sum[v]=0,vis[v]=new set<int>();
for(auto u : g[v])
if(u != p && u != bigChild)
for(int p = st[u]; p < st[u]+sz[u]; p++)
cnt[ col[ ver[p] ] ]++,res[v]=max(res[v],cnt[ col[ ver[p] ] ]);
cnt[ col[v] ]++,res[v]=max(res[v],cnt[ col[v] ]);
if(bigChild != -1&&res[v]>res[bigChild])
{
vis[v]=new set<int>();
sum[v]=0;
}
for(auto u : g[v])
if(u != p && u != bigChild)
for(int p = st[u]; p < st[u]+sz[u]; p++)
if(cnt[col[ver[p]]]==res[v]&&!(*vis[v]).count(col[ver[p]])) sum[v]+=col[ver[p]],(*vis[v]).insert(col[ver[p]]);
if(cnt[col[v]]==res[v]&&!(*vis[v]).count(col[v])) sum[v]+=col[v],(*vis[v]).insert(col[v]);
//now cnt[c] is the number of vertices in subtree of vertex v that has color c. You can answer the queries easily.
if(keep == 0)
for(int p = st[v]; p < st[v]+sz[v]; p++)
cnt[ col[ ver[p] ] ]--;
}
int main()
{
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%d",&col[i]);
}
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
int tim=1;
getsz(1,0,tim);
dfs(1,0,1);
for(int i=1;i<=n;i++) printf("%lld ",sum[i]);
}
大佬的代码(优化了我上面说的那个情况)
//God & me // ya mahdi adrekni
//@Shayan_Cheshmjahan: Oh my friend, congratulations!
#include <bits/stdc++.h>
#define pb push_back
using namespace std;
const int maxn=1e5+12;
int n,col[maxn],sz[maxn],maxx,cnt[maxn];
long long ans[maxn],can;
bool badboy[maxn];
vector<int>g[maxn];
void dastan(int v=0,int p=-1){
sz[v]=1;
for(auto &u:g[v])
if(u!=p)dastan(u,v),sz[v]+=sz[u];
}
void reval(int x){
if(maxx<++cnt[x])
maxx=cnt[x],can=x;
else if(maxx==cnt[x])
can+=x;
}
void add(int v,int p=-1){
reval(col[v]);
for(auto &u:g[v])
if(u!=p && !badboy[u])
add(u,v);
}
void rem(int v,int p=-1){
cnt[col[v]]--;
for(auto &u:g[v])
if(u!=p && !badboy[u])
rem(u,v);
}
void dfs(int v=0,int p=-1,bool hrh=0){//HRH?! mipasandam!
int mx=0,big=-1;
for(auto &u:g[v])
if(u!=p && sz[u]>mx)
mx=sz[u],big=u;
for(auto &u:g[v])
if(u!=p && u!=big)
dfs(u,v,1);
if(big+1)
dfs(big,v),badboy[big]=1;
add(v,p);
if(big+1)
badboy[big]=0;
ans[v]=can;
if(hrh)
rem(v,p),maxx=can=0;
}
main(){
scanf("%d",&n);
for_each(col,col+n,[](int &x){scanf("%d",&x);});
for(int i=1,v,u;i<n;i++)scanf("%d%d",&v,&u),g[--v].pb(--u),g[u].pb(v);
dastan(),dfs();
for_each(ans,ans+n,[](long long &x){printf("%lld ",x);});putchar('\n');
return 0;
}