#include <bits/stdc++.h>
using namespace std;
using lint = long long;
using LL = long long;

// Capacity of the per-node arrays (nodes are 1-based). maxm and inf appear
// unused in the code shown here.
constexpr lint maxn = 200000 + 5;
constexpr lint maxm = 400000 + 5;
constexpr lint inf = 0x3f3f3f3f3f;

// vis[x] : set to 1 once x has been used as a centroid (dfs_div).
// g[x]   : largest piece left if x were removed (get_root).
// sz[x]  : subtree size from the most recent traversal that visited x.
// root   : node chosen by the last get_root call.
// sum    : size of the component get_root is searching.
// c[x]   : colour of node x.
// ve[x]  : adjacency list of x.
lint vis[maxn], g[maxn], sz[maxn], root, sum, c[maxn];
std::vector<lint> ve[maxn];

// Records the directed edge x -> y (main adds both directions).
void add_edge(lint x, lint y) {
    ve[x].emplace_back(y);
}

// Centroid-step accumulators:
// sz_sum        : total size of child subtrees already folded into color[].
// color_cnt     : number of distinct colours on the current dfs_cnt path.
// pre_color_sum : running sum of color[] (dfs_cnt removes colours on its path).
// color[k]      : per-colour counts built by dfs_init, cleared by dfs_init2.
// in_st[k]      : occurrences of colour k on the current DFS path.
// ans[x]        : accumulated result for x; main prints ans[x] + 1.
lint sz_sum, color_cnt, pre_color_sum, color[maxn], in_st[maxn], ans[maxn];
// Walks the subtree hanging off the current centroid through x (parent f),
// recomputing sz[] with the centroid cut off. For every visited node u it adds
// to ans[u] the distinct-colour totals of the paths from u to:
//   - the nodes of the child subtrees already folded into color[] by dfs_init
//     (sz_sum of them), and
//   - when flag == 1, the centroid itself.
// While visiting u:
//   - in_st[k] counts colour-k nodes on the path from the subtree root to u;
//   - color_cnt is the number of distinct colours on that path;
//   - pre_color_sum is the sum of color[k] over colours k NOT on that path
//     (dfs_init keeps pre_color_sum equal to the sum of color[]).
// NOTE(review): the centroid is read from the global `root`. dfs_div only calls
// this before it recurses, i.e. while root still equals the current centroid.
void dfs_cnt( lint x,lint f,lint flag ){
sz[x]=1;
if( !in_st[c[x] ] ){
// First occurrence of this colour on the path. It is now counted on u's
// side, so stop counting it through color[] to avoid double counting.
color_cnt++;
pre_color_sum -= color[ c[x] ];
}
in_st[ c[x] ]++;
// Centroid colour absent from u's side: every target path also picks up
// the centroid's colour (hence color_cnt+1), and color[c[root]] must not be
// added a second time.
if( !in_st[c[root] ] )
ans[x] += pre_color_sum-color[ c[root] ] + (color_cnt+1) *( sz_sum+flag );
else ans[x] += pre_color_sum + color_cnt * (sz_sum+flag );
for( lint i = 0;i < ve[x].size();i++ ){
lint y = ve[x][i];
if( vis[y] || y==f ) continue;
dfs_cnt( y,x,flag );
sz[x] +=sz[y];
}
// Undo this node's colour on the way back up.
in_st[ c[x] ]--;
if(!in_st[ c[x] ] ){
color_cnt--;
pre_color_sum += color[ c[x] ];
}
}
// Folds the subtree rooted at x (entered from f, never crossing vis[] nodes)
// into the per-colour accumulators. Every node whose colour has not appeared
// higher on the path from where this traversal started adds its sz[] to
// color[] of that colour and to pre_color_sum.
void dfs_init(lint x, lint f) {
    const lint col = c[x];
    const bool first_on_path = (in_st[col] == 0);
    if (first_on_path) {
        pre_color_sum += sz[x];
        color[col] += sz[x];
    }
    ++in_st[col];
    for (const lint y : ve[x]) {
        if (y != f && !vis[y]) {
            dfs_init(y, x);
        }
    }
    --in_st[col];
}
// Clears color[] for the colour of every node reachable from x (entered from
// f) without stepping onto a node already marked in vis[].
void dfs_init2(lint x, lint f) {
    color[c[x]] = 0;
    for (const lint y : ve[x]) {
        if (!vis[y] && y != f) {
            dfs_init2(y, x);
        }
    }
}
// Recomputes sz[] over the component of size `sum` that contains x (entered
// from f, never crossing vis[] nodes). g[x] records the largest piece that
// would remain if x were removed, and the node with the smallest g[] seen so
// far (compared against g[root]) is stored in root.
void get_root(lint x, lint f) {
    g[x] = 0;
    sz[x] = 1;
    for (const lint y : ve[x]) {
        if (y == f || vis[y]) continue;
        get_root(y, x);
        sz[x] += sz[y];
        g[x] = max(g[x], sz[y]);
    }
    // The part of the component outside x's subtree.
    const lint above = sum - sz[x];
    if (above > g[x]) g[x] = above;
    if (g[x] < g[root]) root = x;
}
// Processes the component whose centroid is x, then recurses into the pieces
// left after removing x.
// Pass 1 (adjacency order): for each child subtree, dfs_cnt(flag = 1) pairs its
//   nodes with the subtrees handled before it and with the centroid. dfs_init
//   then folds the subtree into color[] / pre_color_sum.
// Pass 2 (reverse order, flag = 0): pairs each subtree with the ones that came
//   after it in pass 1. Together the two passes give every node the paths to
//   all nodes in the other child subtrees.
// Paths inside a single child subtree are handled by the recursive calls.
void dfs_div( lint x ){
vis[x] = 1;  // cut x off from every later traversal
sz_sum = 0;
sz[x]=1;
pre_color_sum=0;
// Pass 1. sz[x] ends up as the size of the whole component.
for( lint i = 0;i < ve[x].size();i++ ){
lint y = ve[x][i];
if(vis[y]) continue;
dfs_cnt(y,x,1);
sz[x]+=sz[y];
dfs_init( y,x );
sz_sum += sz[y];
}
// Centroid's own total over the other sz[x]-1 nodes. That is the colours on
// the far side of each path (pre_color_sum), plus 1 for each node whose path
// lacks c[x] (sz[x]-1 - color[c[x]]). The x->x path is added later by main.
ans[x] += pre_color_sum +sz[x]-1 -color[ c[x] ] ;
// Reset the accumulators before pass 2.
sz_sum = 0;
pre_color_sum=0;
dfs_init2(x,0);
// Pass 2. size()-1 is converted to lint, so an empty list gives -1 and the loop
// is skipped (modular conversion; guaranteed since C++20, the usual behaviour before).
for( lint i =ve[x].size()-1;i >= 0;i-- ){
lint y= ve[x][i];
if(vis[y]) continue;
dfs_cnt(y,x,0);
dfs_init(y,x);
sz_sum += sz[y];
}
dfs_init2(x,0);
// Recurse. sz[y] (from dfs_cnt) is the size of y's piece. Recursing into one
// piece does not touch the nodes or sz[] of the sibling pieces.
for( lint i = 0; i < ve[x].size() ;i++ ){
lint y = ve[x][i];
if( vis[y] ) continue;
root = 0;sum = sz[y];
get_root( y,x );
dfs_div( root );
}
}
int main() {
lint n;
scanf("%lld",&n);
for( lint i = 1;i <= n;i++ ) scanf("%lld",&c[i]);
for( lint x,y,i = 1;i <= n-1;i++ ){
scanf("%lld%lld",&x,&y);
add_edge(x,y);add_edge(y,x);
}
g[0]=n;sum = n;
get_root(1,0);
dfs_div(root);
for( lint i = 1;i <= n;i++ ){
printf("%lld\n",ans[i]+1);
}
return 0;
}
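这是树上差分的80分代码，实在看不出来哪错了QAQ（English: below is the tree-difference (树上差分) solution; it only scores 80 points and I can't spot the bug. QAQ）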
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef LL lint;
// Node ids and colour values are used directly as array indices, so both must
// stay below maxn. Each undirected edge takes two slots of the edge pool.
const lint maxn = 1000005;
const lint maxm = 2000005;
// c[x]     : colour of node x.
// color[k] : filled by dfs(). It is the number of nodes that have a colour-k
//            node on their path to the root (inclusive). n - color[k] is the
//            size of the root's component once every colour-k node is deleted
//            (0 when k == c[1]).
lint color[maxn],c[maxn];
// he/ver/ne: linked-list adjacency. he[x] is the first edge slot of x
//            (0 = none), ver[e] is the target of slot e, ne[e] the next slot.
// tot: last used edge slot. val[y]: see dfs(). n: number of nodes.
lint he[maxn],ver[maxm],ne[maxm],tot,val[maxn],n;

// Resets the edge pool. The first edge lands in slot 2, so slot 0 can act as
// the end-of-list marker in he[]/ne[].
void init(){
    tot = 1;
}

// Appends the directed edge x -> y to x's adjacency list.
void add( lint x,lint y ){
    ver[++tot] = y;
    ne[tot] = he[x];
    he[x] = tot;
}

// sz[x]: size of x's subtree with the tree rooted at node 1 (set by dfs()).
lint sz[maxn];

// First pass, rooted at node 1. Fills sz[]. For each child y of x it sets
//   val[y] = number of nodes of subtree(y) still connected to y after every
//            colour-c[x] node is deleted (0 when c[y] == c[x]).
// It also builds color[]: when dfs(x) returns, color[c[x]] has grown by exactly
// sz[x] compared with its value on entry.
void dfs( lint x,lint f ){
    sz[x]=1;
    for( lint cure = he[x];cure;cure = ne[cure] ){
        lint y = ver[cure];
        if( y == f ) continue;
        lint pre = color[ c[x] ];
        dfs(y,x);
        sz[x] += sz[y];
        // color[c[x]] - pre = nodes of subtree(y) already lying under a
        // colour-c[x] node inside subtree(y). The rest form y's component.
        val[y] = sz[y] - (color[ c[x] ] - pre);
        // From now on the whole of subtree(y) counts as lying under x.
        color[ c[x] ] = pre + sz[y];
    }
    color[ c[x] ]++;
}

// color_val[k]: while dfs_pre() is inside the subtree of a node w whose parent
//               has colour k, this holds val[w] (the innermost such w wins);
//               otherwise 0.
// v[x]        : res[x] - res[parent of x]; built by dfs_pre2() and dfs_pre().
lint color_val[maxn],v[maxn];

// Second pass, part 2. Moving from parent f to x changes the component sizes
// of only two colours:
//   - colour c[f] (if c[f] != c[x]): f is deleted, and x sits in a component
//     of size val[x];
//   - colour c[x] (if c[x] != c[f]): x is deleted, and f sat in the component
//     below the deepest colour-c[x] ancestor (color_val[c[x]]), or in the root
//     component when there is no such ancestor (dfs_pre2 subtracts that case).
// If c[x] == c[f], both nodes are deleted for that colour and nothing changes.
// For the root call (f == 0), c[0] == 0 is a dummy colour (assumes input
// colours are positive, as main's colour scan from 1 implies; TODO confirm).
void dfs_pre( lint x,lint f ){
    v[x] += val[x];   // val[x] == 0 whenever c[x] == c[f]
    // Bug fix: subtract only on a real colour change. When c[x] == c[f], f's
    // colour is not yet recorded in color_val (that happens just below, for
    // x's subtree). color_val[c[x]] would then hold a stale val[] from a higher
    // ancestor and wrongly lower res[] for all of x's subtree.
    if( c[x] != c[f] ) v[x] -= color_val[ c[x] ];
    lint pre = color_val[ c[f] ];
    color_val[ c[f] ] = val[x];
    for( lint cure = he[x];cure;cure = ne[cure] ){
        lint y = ver[cure];
        if( y == f ) continue;
        dfs_pre(y,x);
    }
    color_val[ c[f] ] = pre;
}

// cur   : running prefix sum for dfs_cal(). main() seeds it with res[1], the sum
//         over present colours k of n - color[k].
// res[x]: sum over all colours of the size of x's component after deleting
//         that colour (0 for x's own colour). main prints
//         n * #colours - res[x].
lint cur = 0;
lint res[maxn];

// Third pass: res[x] = seed + sum of v[] on the root-to-x path.
// cur is restored before returning.
void dfs_cal( lint x,lint f ){
    lint pre = cur;
    cur += v[x];
    res[x]=cur;
    for( lint cure = he[x];cure;cure = ne[cure] ){
        lint y = ver[cure];
        if( y == f ) continue;
        dfs_cal(y,x);
    }
    cur = pre;
}

// in_st[k]: number of colour-k nodes on the current root path (dfs_pre2).
//           main() pre-sets in_st[c[1]] = 1 so the root is skipped.
// vis[k]  : main() uses it to mark colours that occur in the input.
lint in_st[maxn],vis[maxn];

// Second pass, part 1. Consider each node x whose colour does not occur on the
// path from the root to its parent. For colour c[x], the parent lay in the
// root component of size n - color[c[x]], and x does not, so that amount is
// subtracted from v[x].
void dfs_pre2( lint x,lint f ){
    if( !in_st[ c[x] ] ){
        v[ x ] -= n-color[ c[x] ];
    }
    in_st[ c[x] ]++;
    for( lint cure = he[x];cure;cure = ne[cure] ){
        lint y = ver[cure];
        if( y == f ) continue;
        dfs_pre2(y,x);
    }
    in_st[ c[x] ]--;
}
int main(){
lint color_cnt = 0;
scanf("%lld",&n);
init();
for( lint i = 1;i <= n;i++ ) {
scanf("%lld",&c[i]);
if( !vis[ c[i] ] ){
color_cnt++;
vis[ c[i] ]=1;
}
}
for( lint x,y,i = 1;i <=n-1;i++ ){
scanf("%lld%lld",&x,&y);
add(x,y);add(y,x);
}
dfs(1,0);
in_st[ c[1] ] = 1;
for( lint i = 1;i < maxn;i++ ) {
if( vis[i] )
cur += n-color[i];
}
dfs_pre2( 1,0 );
dfs_pre(1,0);
dfs_cal( 1,0 );
for( lint i = 1;i <= n;i++ ){
printf("%lld\n", n*color_cnt-res[i] );
}
return 0;
}