题意:给你一颗N个节点的树,每个节点有一个颜色,让你求式子的值。
题解:挂着点分治里面的练习题,其他的与点分治差不多,只是中间的维护不同,根据题意它两个点只间的值是路径是不同颜色的数量,划分分治中点后,路径被分为经过中点的和没有经过中点的路径,没有经过中点的路径是子问题,处理相同,考虑经过中点的路径怎么统计答案,首先通过一次dfs是可以统计出经过当前分治中点的路径的不同颜色和,以及每个子节点的贡献,同时每颗子树颜色出现是它掌管的子节点数量也是可以求出的,和每种颜色掌控的总节点和,那么首先通过分治中点,两个点实际上就是两个子树之间的,那么对于一个子树中的节点会增加的贡献就是总的贡献减去root对应儿子子树的贡献,那么如果对于一个颜色出现在了当前子树中,那么它新增加的贡献就是其他子树中没有被这种颜色掌控的节点数,通过第一个dfs维护的颜色掌控的总节点和,我们每次统计一颗子树时,先通过一次dfs将这颗子树的对总颜色和的影响删去,在进入统计答案,然后在通过一次dfs加上,最后通过dfs将数组清0,直接memset复杂度退化为平方级别。总的时间复杂度是nlong级别。
后记:菜鸡想了一下午,终于在晚上想出来了,写发博客纪念一下。
AC代码:
#include<stdio.h>
#include<vector>
#include<string.h>
#include<algorithm>
#include<iostream>
#include<math.h>
#include<queue>
#include<set>
#include<map>
#include<stack>
#define ll long long
using namespace std;
int read(){
char c;int x=0,y=1;while(c=getchar(),(c<'0'||c>'9')&&c!='-');
if(c=='-') y=-1;else x=c-'0';while(c=getchar(),c>='0'&&c<='9')
x=x*10+c-'0';return x*y;
}
const int maxn=3e5+10;
const int mod=1e9+7;
const int inf=1e9;
int sz[maxn],ma[maxn],root,sum,n,m;
bool vis[maxn];
struct node{
int to,nex;
};
node side[maxn];
int head[maxn],tot;
int num[maxn],num1[maxn],ci[maxn],visc[maxn];
ll ans[maxn],di[maxn];
void add(int u,int v){
side[tot].to=v;
side[tot].nex=head[u];
head[u]=tot++;
}
void getroot(int now,int fa){
sz[now]=1;
ma[now]=0;
for(int i=head[now];i!=-1;i=side[i].nex){
int v=side[i].to;
if(v==fa||vis[v]) continue;
getroot(v,now);
sz[now]+=sz[v];
ma[now]=max(ma[now],sz[v]);
}
ma[now]=max(ma[now],sum-sz[now]);
if(!root||ma[now]<ma[root]) root=now;
}
void get_sz(int now,int fa){
sz[now]=1;
for(int i=head[now];i!=-1;i=side[i].nex){
int v=side[i].to;
if(v==fa||vis[v]) continue;
get_sz(v,now);
sz[now]+=sz[v];
}
}
void get_dis(int now,int fa,int nu){
if(!visc[ci[now]]) num[now]=sz[now],num1[ci[now]]+=sz[now],nu++;
visc[ci[now]]++;
di[now]=nu;
for(int i=head[now];i!=-1;i=side[i].nex){
int v=side[i].to;
if(v==fa||vis[v]) continue;
get_dis(v,now,nu);
di[now]+=di[v];
}
visc[ci[now]]--;
}
void modify(int now,int fa,int va){
if(num[now]) num1[ci[now]]+=va*num[now];
if(va==0) num[now]=num1[ci[now]]=0;
for(int i=head[now];i!=-1;i=side[i].nex){
int v=side[i].to;
if(v==fa||vis[v]) continue;
modify(v,now,va);
}
}
void get_ans(int now,int fa,ll va,int siz){
if(num[now])
va+=(siz-num1[ci[now]]);
ans[now]+=va;
for(int i=head[now];i!=-1;i=side[i].nex){
int v=side[i].to;
if(v==fa||vis[v]) continue;
get_ans(v,now,va,siz);
}
}
void work(int now){
get_dis(now,0,0);
ans[now]+=di[now];
for(int i=head[now];i!=-1;i=side[i].nex){
int v=side[i].to;
if(vis[v]) continue;
modify(v,now,-1);
get_ans(v,now,di[now]-di[v],sz[now]-sz[v]);
modify(v,now,1);
}
modify(now,0,0);
}
void solve(int now){
vis[now]=1;
get_sz(root,0);
work(now);
for(int i=head[now];i!=-1;i=side[i].nex){
int v=side[i].to;
if(vis[v]) continue;
sum=sz[v];
root=0;
getroot(v,now);
solve(root);
}
}
int main( ){
n=read();
memset(head,-1,sizeof(head));
tot=0;
for(int a=1;a<=n;a++) ci[a]=read();
for(int a=1;a<n;a++){
int u=read(),v=read();
add(u,v);
add(v,u);
}
sum=n;
root=tot=0;
ma[0]=n;
getroot(1,0);
solve(root);
for(int a=1;a<=n;a++) printf("%lld\n",ans[a]);
}