官方题解:单独考虑每一种颜色,答案就是对于每种颜色至少经过一次这种的路径条数之和。反过来思考只需要求有多少条路径没有经过这种颜色即可。直接做可以采用虚树的思想(不用真正建出来),对每种颜色的点按照 dfs 序列排个序,就能求出这些点把原来的树划分成的块的大小。这个过程实际上可以直接一次 dfs 求出。
令sum[i]=已经dfs遍历过的点中,以i颜色的点为根的子树的节点个数之和(不重复计数)
当遍历到u点时,对u的儿子v进行dfs,dfs后,sum[c[u]]增长的值为increase,以v为根的子树的大小为size[v]
那以u为根的子树,不包含颜色c[u]且包含v点的联通块大小为size[v]−increase
C2size[v]−increase就是u为根的子树,不包含颜色c[u]且包含v点的联通块的没有经过颜色c[u]的路径数量
因为遍历完u后,sum[c[u]]应该比遍历u前多了size[u],而sum[c[u]]在每次dfs儿子时已经递增了一部分
当遍历完u的所有儿子时,sum[c[u]]+=(size[u]−dfs每个儿子后的increase之和),以维护sum[c[u]]
例如:
如图,size[v1]=3,遍历v1前,sum[白]=0
遍历完v1后,sum[白]=1
increase1=1,那以u为根的子树,不包含白色且包含v1点的联通块大小为size[v1]−increase1=2
遍历完v2后,sum[白]=3
increase2=2,那以u为根的子树,不包含白色且包含v2点的联通块大小为size[v2]−increase2=2
u的儿子遍历完后,sum[白]+=size[u]−(increase1+increase2)
方便起见,设根为1,显然当c[i]!=c[1]时,包含点1,而不含c[i]颜色的点的路径被忽略了
而这些路径所在联通块大小显然就是n−sum[c[i]],另外处理一下就好了
#include<stdio.h>
#include<iostream>
#include<stdlib.h>
#include<algorithm>
#include<vector>
#include<string.h>
#include<string>
#include<math.h>
#include<memory.h>
#define ll long long
#define pii pair<int,int>
#define pll pair<ll,ll>
#define MEM(a,x) memset(a,x,sizeof(a))
#define lowbit(x) ((x)&-(x))
using namespace std;
//const int inf=0x3f3f3f3f;
const int MOD = 1e9+7;
const int N = 200000 + 50;
int c[N],size[N],sum[N];
int cNum[N];//cNum[i]=j 颜色为i的点有j个
bool visC[N];//是否遍历过颜色i
ll sub;
vector<int>G[N];
inline ll f(ll x){
return x*(x-1)/2;
}
void dfs(int u,int pre){
int increaseSum=0;
size[u]=1;
for(int v:G[u]){
if(pre==v){
continue;
}
int preSum=sum[c[u]];
dfs(v,u);
int increase=sum[c[u]]-preSum;
sub+=f(size[v]-increase);
increaseSum+=increase;
size[u]+=size[v];
}
sum[c[u]]+=size[u]-increaseSum;
}
ll slove(int n){
sub=0;
MEM(sum,0);
dfs(1,-1);
int cSum=0;//颜色数量
for(int i=1;i<=n;++i){
cSum+=(cNum[i]>0);
}
MEM(visC,0);
visC[c[1]]=1;
for(int i=2;i<=n;++i){
if(!visC[c[i]]){
sub+=f(n-sum[c[i]]);
visC[c[i]]=1;
}
}
return f(n)*cSum-sub;
}
int main()
{
//freopen("/home/lu/code/r.txt","r",stdin);
int n,T=1;
while(~scanf("%d",&n)){
MEM(cNum,0);
for(int i=1;i<=n;++i){
G[i].clear();
scanf("%d",&c[i]);
++cNum[c[i]];
}
for(int i=0;i<n-1;++i){
int u,v;
scanf("%d%d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
printf("Case #%d: %lld\n",T++,slove(n));
}
return 0;
}