单独考虑每一种颜色,答案就是对于每种颜色至少经过一次这种的路径条数之和。反过来思考只需要求有多少条路径没有经过这种颜色即可。直接做可以采用虚树的思想(不用真正建出来),对每种颜色的点按照 dfs 序列排个序,就能求出这些点把原来的树划分成的块的大小。这个过程实际上可以直接一次 dfs 求出。
#include<stdio.h>
#include<vector>
#include<string.h>
#define N 200005
using namespace std;
typedef long long ll;
ll col[N],size[N];
ll mark[N];
vector<ll>vt[N];
ll sum[N];
ll ans=0;
void dfs(ll u,ll fa)
{
ll all=0;
size[u]=1;
for(ll i=0;i<vt[u].size();i++)
{
ll to=vt[u][i];
if(to==fa)continue;
ll s=sum[col[u]];
dfs(to,u);
size[u]+=size[to];
ll step=sum[col[u]]-s;
all+=step;
ans+=(size[to]-step)*(size[to]-step-1)/2;
}
sum[col[u]]+=-all+size[u];
}
int main()
{
ll n;
ll cas=1;
while(~scanf("%lld",&n))
{
memset(mark,0,sizeof(mark));
memset(size,0,sizeof(size));
memset(sum,0,sizeof(sum));
for(ll i=0;i<N;i++)vt[i].clear();
ll gg=0;
ans=0;
for(ll i=1;i<=n;i++)
{
scanf("%lld",&col[i]);
gg+=mark[col[i]]^1;
mark[col[i]]=1;
}
for(ll i=0;i<n-1;i++)
{
ll u,v;
scanf("%lld%lld",&u,&v);
vt[u].push_back(v);
vt[v].push_back(u);
}
dfs(1,-1);
ll ANS=n*(n-1)*gg/2;
for(ll i=1;i<=n;i++)
{
if(gg==col[1]||!mark[i])continue;
ans+=(n-sum[i])*(n-sum[i]-1)/2;
}
printf("Case #%lld: %lld\n",cas++,ANS-ans);
}
}