题目大意
给定一棵
n
个节点的树,每个点有一个颜色种类
对于每一个点
x
,你需要统计从
1≤n≤3×105,0≤ci≤n
题目分治
首先这题虚树肯定可以做,这里不讲。
考虑使用点分治,先不考虑有多种颜色。假设我只想统计出现过某一种颜色的路径总数。
对于分治重心
c
,在分治过程中做到点
∙
如果
x
到
∙
如果
x
到
现在考虑将其扩展到多种颜色上。
对于分治重心
c
,在分治过程中做到点
∙
设其到
c
路径上颜色种类数为
∙
对于那些没有出现过的颜色种类,我们考虑先预处理不包含颜色
i
的到
时间复杂度
O(nlogn)
。
代码实现
#include <iostream>
#include <cstdio>
#include <cctype>
using namespace std;
typedef long long LL;
int read()
{
int x=0,f=1;
char ch=getchar();
while (!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();
while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return x*f;
}
int buf[30];
void write(LL x)
{
if (x<0) putchar('-'),x=-x;
for (;x;x/=10) buf[++buf[0]]=x%10;
if (!buf[0]) buf[++buf[0]]=0;
for (;buf[0];putchar('0'+buf[buf[0]--]));
}
const int N=300050;
const int E=N<<1;
int last[N],fa[N],size[N],col[N],f[N],ext[N],que[N];
int nxt[E],tov[E];
bool vis[N];
LL ans[N];
int n,tot,head,tail,cur;
LL sum;
void insert(int x,int y){tov[++tot]=y,nxt[tot]=last[x],last[x]=tot;}
int core(int og)
{
int i,x,y,ret,rets=n,tmp;
for (head=0,fa[que[tail=1]=og]=0;head<tail;)
for (size[x=que[++head]]=1,i=last[x];i;i=nxt[i])
if ((y=tov[i])!=fa[x]&&!vis[y])
fa[que[++tail]=y]=x;
for (head=tail;head>1;--head) size[fa[que[head]]]+=size[que[head]];
for (head=1;head<=tail;++head)
{
for (tmp=size[og]-size[x=que[head]],i=last[x];i;i=nxt[i])
if ((y=tov[i])!=fa[x]&&!vis[y]) tmp=max(tmp,size[y]);
if (tmp<rets) ret=x,rets=tmp;
}
return ret;
}
void dfs(int x)
{
size[x]=1;
for (int i=last[x],y;i;i=nxt[i])
if ((y=tov[i])!=fa[x]&&!vis[y])
fa[y]=x,dfs(y),size[x]+=size[y];
}
void count(int x,int *f,int sig)
{
if (!ext[col[x]]++) f[col[x]]+=sig*size[x],sum+=sig*size[x],++cur;
for (int i=last[x],y;i;i=nxt[i])
if ((y=tov[i])!=fa[x]&&!vis[y]) count(y,f,sig);
if (!--ext[col[x]]) --cur;
}
void calc(int x,int siz,int c)
{
if (!ext[col[x]]++) sum-=f[col[x]],++cur;
ans[x]+=1ll*siz*cur+sum,ans[c]+=cur;
for (int i=last[x],y;i;i=nxt[i])
if ((y=tov[i])!=fa[x]&&!vis[y]) calc(y,siz,c);
if (!--ext[col[x]]) sum+=f[col[x]],--cur;
}
void solve(int x)
{
int c=core(x);
++ans[c],size[c]=1;
for (int i=last[c],y;i;i=nxt[i])
if (!vis[y=tov[i]]) fa[y]=c,dfs(y),size[c]+=size[y];
for (int i=last[c],y;i;i=nxt[i])
if (!vis[y=tov[i]]) count(y,f,1);
for (int i=last[c],y;i;i=nxt[i])
if (!vis[y=tov[i]]) count(y,f,-1),++ext[col[c]],sum-=f[col[c]],cur=1,calc(y,size[c]-size[y],c),--ext[col[c]],sum+=f[col[c]],cur=0,count(y,f,1);
for (int i=last[c],y;i;i=nxt[i])
if (!vis[y=tov[i]]) count(y,f,-1);
vis[c]=1;
for (int i=last[c],y;i;i=nxt[i])
if (!vis[y=tov[i]]) solve(y);
}
int main()
{
freopen("mushroom.in","r",stdin),freopen("mushroom.out","w",stdout);
n=read();
for (int i=1;i<=n;++i) col[i]=read();
for (int i=1,x,y;i<n;++i) x=read(),y=read(),insert(x,y),insert(y,x);
solve(1);
for (int i=1;i<=n;++i) write(ans[i]),putchar('\n');
fclose(stdin),fclose(stdout);
return 0;
}