题意
给出一个以1 为根的树,对于每个节点,需要输出每个节点的子树中,相同颜色的距离之和。
经典的静态查询子树信息 我们可以考虑 dsu on tree
我们用sum[ c[u ]] 表示 统计到 当前结点的深度之和
我们用 num[ c[u] ] 表示 统计到当前节点的该颜色的结点数
我们先想一下 每个子树v对u结点的贡献
sum[ c[ u ] ] − n u m [ c [ u ] ] ∗ d [ u ]
对于结点之间的贡献我们可以暴力的统计
s u m [ c [ v ] ] − n u m [ c [ v ] ] ∗ d [ u ] + ( d [ v ] − d [ u ] ) ∗ n u m [ c [ v ] ]
具体细节看代码
#include <bits/stdc++.h>
using namespace std;
#define int long long
//typedef long long ll;
typedef pair<int,int> pii;
#define x first
#define y second
#define pb push_back
#define inf 1e18
#define IOS std::ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define fer(i,a,b) for(int i=a;i<=b;i++)
#define der(i,a,b) for(int i=a;i>=b;i--)
const int maxn=1e5+10;
const int mod=1e9+7;
/*ll qsm(int a,int b)
{ll res=1; while(b){ if(b&1) res=res*a%mod; a=a*a%mod; b>>=1; } return res;}
*/
const int N=2e5+10;
const int M=2e5+10;
int ans[N];
int dr[4][2]={{-1,0},{1,0},{0,-1},{0,1}};
int n,k;
int a[N];
int num[N];
int sum[N];
int h[N],e[M],ne[M],cnt;
int d[N];
int son[N];
int siz[N];
int now;
void add(int a,int b)
{
e[cnt]=b,ne[cnt]=h[a],h[a]=cnt++;
}
void dfs(int u,int fa)
{
d[u]=d[fa]+1;
siz[u]=1;
// int ma=-1;
for(int i=h[u];~i;i=ne[i])
{
int v=e[i];
if(v==fa)continue;
dfs(v,u);
siz[u]+=siz[v];
/* if(siz[v]>ma)
{
son[u]=v;
ma=siz[v];
}*/
if(siz[son[u]] < siz[v]) son[u] = v;
}
}
void update(int u,int fa)
{
sum[a[u]]+=d[u];
num[a[u]]++;
for(int i=h[u];~i;i=ne[i])
{
int v=e[i];
if(v==fa||v==now)continue;
update(v,u);
}
}
void del(int u,int fa)
{
num[a[u]]--;
sum[a[u]]-=d[u];
for(int i=h[u];~i;i=ne[i])
{
int v=e[i];
if(v==fa||v==now)continue;
del(v,u);
}
}
void cal(int u,int fa,int f)
{
for(int i=h[u];~i;i=ne[i])
{
int v=e[i];
if(v==fa||v==now)continue;
ans[f]+=sum[a[v]]-d[f]*num[a[v]]+(d[v]-d[f])*num[a[v]];
cal(v,u,f);
if(u==f)update(v,u);
}
}
void dsu(int u,int fa,bool f)
{
for(int i=h[u];~i;i=ne[i])
{
int v=e[i];
if(v==fa||v==son[u])continue;
dsu(v,u,false);
}
if(son[u]) {
dsu(son[u], u, true);
now = son[u];
}
cal(u, fa, u);
sum[a[u]] += d[u], num[a[u]] ++;
now = 0;
ans[fa] += sum[a[fa]] - d[fa] * num[a[fa]];
if(!f) del(u, fa);
}
void get(int u,int fa)
{
for(int i=h[u];~i;i=ne[i])
{
int v=e[i];
if(v==fa)continue;
get(v,u);
ans[u]+=ans[v];
}
}
void solve()
{
cin>>n;
memset(h,-1,sizeof(h));
fer(i,1,n)cin>>a[i];
fer(i,1,n-1)
{
int a,b;
cin>>a>>b;
add(a,b);
add(b,a);
}
dfs(1,0);
dsu(1,0,true);
get(1,0);
fer(i,1,n-1)
cout<<ans[i]<<" ";
cout<<ans[n]<<endl;
}
signed main()
{
IOS;
int _=1;
//cin>>_;
while(_--) solve();
return 0;
}