暴力枚举即可
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef LL lint;
const lint inf = 0x3f3f3f3f;
const lint maxn = 100005;
const lint maxm = 200005;
lint a[maxn],tot,he[maxn],ne[maxm],ver[maxm],sum,n;
void add( lint x,lint y ){
ver[++tot] = y;
ne[tot] = he[x];
he[x] = tot;
}
vector<lint> ve[maxn];
void dfs_pre( lint x,lint f ){
for( lint cure = he[x];cure;cure = ne[cure] ){
lint y = ver[cure];
if( y == f ) continue;
ve[x].push_back(a[y]);
if( a[x] == a[y] ) {
sum++;
}
dfs_pre(y,x);
}
sort( ve[x].begin(),ve[x].end() );
}
void solve_pre( lint x ){
lint pre = 1;
lint re = 0;
for( lint i = 1;i < ve[x].size();i++ ){
if( ve[x][i] != ve[x][i-1] ) {
re += pre*(pre-1)/2;
pre=1;
}
else pre++;
}
re += pre*(pre-1)/2;
sum +=re;
}
void pre_work(){
dfs_pre(1,0);
for( lint i = 1;i <= n;i++ ){
solve_pre(i);
}
}
lint ans[maxn];
lint dfs_solve( lint x,lint y ){
lint cur = 0,px = 0,py = 0;
lint res = 0;
while( px < ve[x].size() ){
if( x >= 1 && ve[x][px]==ve[x][px-1] ){
res += cur;
}else {
cur = 0;
while( py < ve[y].size() && ve[y][py] < ve[x][px] ) py++;
while (py < ve[y].size() && ve[y][py] == ve[x][px]) {
cur++;
py++;
}
res += cur;
}
px++;
}
return res;
}
lint dfs_solve2( lint c,lint x ){
lint p1 = lower_bound( ve[x].begin(),ve[x].end(),c )-ve[x].begin();
lint p2 = upper_bound( ve[x].begin(),ve[x].end(),c )-ve[x].begin();
return p2-p1;
}
void dfs( lint x,lint f ){
ans[x]=sum;
if( a[x] == a[f] ) ans[x]--;
//ans[x]-=res[x];
ans[x] -= dfs_solve2( a[x],x );
if( x != 1 )
ans[x] -= dfs_solve2( a[x],f )-1;
lint mm = -inf;
bool flag = false;
for( lint cure = he[x];cure;cure = ne[cure] ){
lint y = ver[cure];
if( y == f )continue;
flag = true;
lint cur = 0;
if( a[y]==a[f] ) cur++;
cur += dfs_solve( x,y );
cur -= dfs_solve2(a[y],y);
cur += dfs_solve2( a[y],f ) - (( a[y] == a[x] )?1:0) ;
mm = max( mm,cur );
}
if(!flag) mm = 0;
ans[x] += mm;
for( lint cure = he[x];cure;cure = ne[cure] ){
lint y= ver[cure];
if( y == f ) continue;
dfs(y,x);
}
}
int main(){
lint m;
tot = 1;
scanf("%lld%lld",&n,&m);
for( lint i = 1;i <= n;i++ ) scanf("%lld",&a[i]);
for( lint x,y,i = 1;i <= n-1;i++ ){
scanf("%lld%lld",&x,&y);
add(x,y);add(y,x);
}
pre_work();
dfs(1,0);
printf("%lld",ans[1]);
for( lint i = 2;i <= n;i++ ){
printf(" %lld",ans[i]);
}
cout << endl;
return 0;
}