题意:一棵树有n个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。
考虑如何用dsu on tree来做。
我们不难想到一个暴力的做法,用一个全局sum数组记录颜色出现次数,对每个节点,遍历其所有子树,然后统计答案,最后再清空sum数组(不清空的话会对影响到其他节点),这个时间复杂度是O(n^2),肯定过不去。
然后我们可以考虑一个优化,遍历到最后一个子树时是不用清空的,因为它不会产生对其他节点影响了,根据贪心的思想我们当然要把节点数最多的子树(即重儿子形成的子树)放在最后,之后我们就有了一个看似比较快的算法,先遍历所有的轻儿子节点形成的子树,统计答案但是不保留数据,然后遍历重儿子,统计答案并且保留数据,最后再遍历轻儿子以及父节点,合并重儿子统计过的答案。
看似这个优化不是很重要,但是时间复杂度会变成O(nlogn)。
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N= 1e5 + 5 ;
int n, u, v, tot, maxn, now;
int a[ N] , in[ N] , out[ N] , id[ N<< 1 ] , size[ N] , d[ N] , son[ N] , sum[ N] , ans[ N] ;
int cnt, head[ N] ;
struct edge{ int next, to; } e[ N<< 1 ] ;
inline void add ( int u, int v)
{
cnt++ ;
e[ cnt] . next= head[ u] ;
e[ cnt] . to= v;
head[ u] = cnt;
}
void dfs1 ( int u, int fa)
{
in[ u] = ++ tot; id[ tot] = u; size[ u] = 1 ;
for ( register int i= head[ u] ; i; i= e[ i] . next)
if ( e[ i] . to!= fa)
{
d[ e[ i] . to] = d[ u] + 1 ;
dfs1 ( e[ i] . to, u) ;
size[ u] + = size[ e[ i] . to] ;
if ( size[ son[ u] ] < size[ e[ i] . to] ) son[ u] = e[ i] . to;
}
out[ u] = tot;
}
inline void ins ( int u, int fa)
{
sum[ a[ u] ] ++ ;
if ( sum[ a[ u] ] > maxn) maxn= sum[ a[ u] ] , now= a[ u] ;
else if ( sum[ a[ u] ] == maxn) now+ = a[ u] ;
for ( register int i= head[ u] ; i; i= e[ i] . next)
if ( e[ i] . to!= fa && e[ i] . to!= son[ u] )
{
for ( register int j= in[ e[ i] . to] ; j<= out[ e[ i] . to] ; ++ j)
{
int x= id[ j] ;
sum[ a[ x] ] ++ ;
if ( sum[ a[ x] ] > maxn) maxn= sum[ a[ x] ] , now= a[ x] ;
else if ( sum[ a[ x] ] == maxn) now+ = a[ x] ;
}
}
}
inline void del ( int u)
{
now= maxn= 0 ;
for ( register int i= in[ u] ; i<= out[ u] ; ++ i)
{
int x= id[ i] ;
sum[ a[ x] ] -- ;
}
}
void dfs ( int u, int fa, int pd)
{
for ( register int i= head[ u] ; i; i= e[ i] . next)
if ( e[ i] . to!= fa && e[ i] . to!= son[ u] ) dfs ( e[ i] . to, u, 0 ) ;
if ( son[ u] ) dfs ( son[ u] , u, 1 ) ;
ins ( u, fa) ;
ans[ u] = now;
if ( ! pd) del ( u) ;
}
signed main ( ) {
scanf ( "%lld" , & n) ;
for ( register int i= 1 ; i<= n; ++ i) scanf ( "%lld" , & a[ i] ) ;
for ( register int i= 1 ; i< n; ++ i) scanf ( "%lld%lld" , & u, & v) , add ( u, v) , add ( v, u) ;
dfs1 ( 1 , 0 ) ;
dfs ( 1 , 0 , 0 ) ;
for ( register int i= 1 ; i<= n; ++ i) printf ( "%lld " , ans[ i] ) ;
return 0 ;
}