题目链接
思路:直接做好像有点困难… 那考虑点分治。
我们把路径分成经过分治重心的和不经过分治重心的。
那么,我们在递归处理某个重心的时候,就可以算出所有点的在此分治重心下的答案。
具体实现:考虑x色是根到当前y位置第一次出现的x色,那么这个x就能提供sz_y(当前分治重心下以y为根的子树大小)的贡献给其他跨过分治重心的点。我们用数组tmp来记录某个颜色的总贡献。所有颜色的总贡献用res表示。我们在计算某个点的答案时,要去掉根到当前点的颜色所产生的贡献。这在dfs过程中可以统计出来,需要注意的是算贡献的过程中要分清被贡献的对象。正反都求一边就能算出题目要求的答案了。
#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define LL long long
using namespace std;
const int N = 1e5 + 11;
int n, m;
vector<pair<int,int> >v[N];
int q[N],col[N];
int rt, son[N], sz[N], vis[N];
void root(int now, int pre) { //找重心
sz[now] = 1;
son[now] = 0;
for (auto k : v[now]) {
if (vis[k.fi] || k.fi == pre)continue;
root(k.fi, now);
sz[now] += sz[k.fi];
son[now] = max(son[now], sz[k.fi]);
}
son[now] = max(son[now], n - sz[now]);
if (!rt || son[rt] > son[now]) {
rt = now;
}
return ;
}
LL ans[N];
int vi[N],tmp[N],so[N];
LL tot;
vector<int>nx,all;
LL als;
void getdis(int now, int pre, int di,LL res,int ch) {
/*处理细节*/
int st=0;
LL dis=0;
assert(res>=0);
if(!vi[col[now]]){
di++;
so[col[now]]+=sz[now];
vi[col[now]]=1;
nx.pb(col[now]);
dis+=tmp[col[now]];
st=1;
}
assert((res-dis)>=0);
ans[now]+=als*(di+1)+res-dis;if(ch)ans[now]+=di+1;
tot+=di;
for(auto k:v[now]){
if(k.fi!=pre &&!vis[k.fi]){
getdis(k.fi,now,di,res-dis,ch);
}
}
if(st)vi[col[now]]=0;
}
LL res;
void get(int now, int pre) {
/*一些处理细节*/
for(auto k:all){
tmp[k]=0;
vi[k]=0;
so[k]=0;
}
all.clear();
vi[col[now]]=1;
all.pb(col[now]);
res=0;tot=0;als=0;
for (auto k : v[now]) {
if (k.fi == pre || vis[k.fi])continue;
getdis(k.fi, now, 0,res,1);
ans[now]+=tot+sz[k.fi];als+=sz[k.fi];
res+=tot;tot=0;
for(auto j:nx){
vi[j]=0;
all.pb(j);
tmp[j]+=so[j];
so[j]=0;
}
nx.clear();
/*统计一下答案*/
}
for(auto k:all){
tmp[k]=0;
vi[k]=0;
}
vi[col[now]]=1;
all.clear();
all.pb(col[now]);
//
reverse(v[now].begin(),v[now].end());
tot=0;res=0;als=0;
for (auto k : v[now]) {
if (k.fi == pre || vis[k.fi])continue;
getdis(k.fi, now, 0,res,0);
res+=tot;tot=0;als+=sz[k.fi];
for(auto j:nx){
vi[j]=0;
tmp[j]+=so[j];
all.pb(j);
so[j]=0;
}
nx.clear();
/*统计一下答案*/
}
}
void dfs(int now) {
vis[now] = 1;
get(now, 0);
for (auto k : v[now]) {
if (vis[k.fi])continue;
rt = 0;
n = sz[k.fi];
root(k.fi, 0);
root(rt,0);
dfs(rt);
}
return ;
}
int main() {
scanf("%d", &n);
for(int i=1;i<=n;i++)scanf("%d",col+i);
for (int i = 1; i < n; i++) {
int s, t;
scanf("%d%d", &s, &t);
v[s].pb({t, 0});
v[t].pb({s, 0});
}
int pr_n=n;
root(1, 0);
root(rt, 0);
dfs(rt);
for(int i=1;i<=pr_n;i++)printf("%lld\n",ans[i]+1);
return 0;
}