考虑统计不包含颜色 i i i 的路径条数。
发现等价于:将所有
c
[
u
]
=
i
c[u]=i
c[u]=i 的点删除后所有连通块的贡献。设
d
p
[
u
]
[
i
]
dp[u][i]
dp[u][i] 表示以
u
u
u 为根的子树,不包括颜色
i
i
i 的连通块的大小。根据定义不难得到:
d
p
[
u
]
[
i
]
=
{
∑
d
p
[
v
]
[
i
]
+1
c
[
u
]
≠
i
0
c
[
u
]
=
i
dp[u][i]= \begin{cases} \sum dp[v][i]\text{+1}&\text c[u]\ne i\\ \\ 0&\ c[u]=i \end{cases}
dp[u][i]=⎩⎪⎨⎪⎧∑dp[v][i]+10c[u]=i c[u]=i
现在只需对于节点
u
u
u 考虑
c
[
u
]
c[u]
c[u] 所产生的贡献。时间复杂度
O
(
n
2
)
O(n^2)
O(n2) 。
当然这个转移式可以优化,就是第二维不太好维护。
提供一个 O ( n l o g n ) O(nlogn) O(nlogn) 的线段树做法。首先把 c [ u ] = i c[u]=i c[u]=i 的点存下来,然后按 d e p [ u ] dep[u] dep[u] 从大到小排序,每次取出一个节点,遍历所有儿子,查询子树中剩余节点的个数,然后把整个子树删去。这样均摊操作次数 O ( n ) O(n) O(n) ,用线段树大力维护即可。
#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define PII pair<int,int>
#define ll long long
#define f(x) 1ll*x*(x-1)/2
#define All(a) a.begin(),a.end()
using namespace std;
const int mx=2e5+5;
//qwq ...
int n,Case,c[mx],siz[mx],son[mx],num,dfn[mx],fa[mx],dep[mx];
struct SegmentTree{
int siz,tag,len;
}t[mx<<2];
vector<int> g[mx],g2[mx];
void dfs(int u,int fath) {
dep[u]=dep[fath]+1; fa[u]=fath; siz[u]=1; dfn[u]=++num;
for(auto v:g[u]) {
if(v==fath) continue;
dfs(v,u),siz[u]+=siz[v];
}
}
void PushUp(int p) {
t[p].siz=t[p<<1].siz+t[p<<1|1].siz;
}
void PushDown(int p) {
if(t[p].tag!=-1) return;
// puts("Ac");
t[p<<1].tag=t[p<<1|1].tag=-1;
t[p<<1].siz=t[p<<1].len;
t[p<<1|1].siz=t[p<<1|1].len;
t[p].tag=0;
}
void build(int p,int l,int r) {
t[p].siz=t[p].len=r-l+1; t[p].tag=0;
if(l==r) return;
int mid=l+r>>1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
}
void update(int p,int l,int r,int ql,int qr){
if(ql<=l&&r<=qr) {
t[p].siz=0; t[p].tag=1; return;
}
if(t[p].tag==1) return;
PushDown(p);
int mid=l+r>>1;
if(ql<=mid) update(p<<1,l,mid,ql,qr);
if(mid<qr) update(p<<1|1,mid+1,r,ql,qr);
PushUp(p);
}
int query(int p,int l,int r,int ql,int qr) {
if(ql<=l&&r<=qr) {
return t[p].siz;
}
if(t[p].tag==1) return 0;
PushDown(p);
int mid=l+r>>1,ans=0;
if(ql<=mid) ans+=query(p<<1,l,mid,ql,qr);
if(mid<qr) ans+=query(p<<1|1,mid+1,r,ql,qr);
return ans;
}
bool cmp(int x,int y) {
return dep[x]>dep[y];
}
signed main() {
// freopen("data.in","r",stdin);
while(scanf("%d",&n)!=EOF) {
ll res=0; num=0;
for(int i=1;i<=n;i++) scanf("%d",&c[i]);
for(int i=1;i<=n;i++) g[i].clear();
for(int i=1;i<n;i++) {
int u,v; scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1,0); build(1,1,n);
for(int i=1;i<=n;i++) g2[i].clear();
for(int i=1;i<=n;i++) g2[c[i]].push_back(i);
for(int i=1;i<=n;i++) {
if(!g2[i].size()) continue;
sort(All(g2[i]),cmp);
res+=f(n);
// build(1,1,n);
t[1].tag=-1; t[1].siz=n;
for(auto u:g2[i]) {
for(auto v:g[u]) {
if(v==fa[u]) continue;
int tmp=query(1,1,n,dfn[v],dfn[v]+siz[v]-1);
res-=f(tmp);
}
update(1,1,n,dfn[u],dfn[u]+siz[u]-1);
}
int tmp=t[1].siz;
res-=f(tmp);
// cout<<res<<endl;
}
printf("Case #%d: %lld\n",++Case,res);
}
}