题意:给一颗树,定义树上路径u到v的价值为路径上不同颜色的节点的数量,问所有路径的总价值为多少
题解:根据题意,所需要求解的价值可以转换为,假设每条路径上包含所有的颜色,也就是n*(n-1)/2*(总颜色个数),然后减去每个颜色对不存在这个颜色的的所有路径的贡献。
dfs树上的每一个节点,对于当前节点,我们对每一个子树,用sum数组记录每一个颜色的截断值,于是这个子树的size-父节点颜色的截断值 来表示该子树与父节点颜色不同的连通块的总个数。
AC代码:
#include<stdio.h>
#include<vector>
#include<string.h>
#define N 200005
using namespace std;
typedef long long ll;
ll col[N],size[N];
ll mark[N];
vector<ll>vt[N];
ll sum[N];
ll ans=0;
void dfs(ll u,ll fa)
{
ll all=0;
size[u]=1;
for(ll i=0;i<vt[u].size();i++)
{
ll to=vt[u][i];
if(to==fa)continue;
ll s=sum[col[u]];
dfs(to,u);
size[u]+=size[to];
ll step=sum[col[u]]-s;
all+=step;
ans+=(size[to]-step)*(size[to]-step-1)/2;
}
sum[col[u]]+=-all+size[u];
}
int main()
{
ll n;
ll cas=1;
while(~scanf("%lld",&n))
{
memset(mark,0,sizeof(mark));
memset(size,0,sizeof(size));
memset(sum,0,sizeof(sum));
for(ll i=0;i<N;i++)vt[i].clear();
ll gg=0;
ans=0;
for(ll i=1;i<=n;i++)
{
scanf("%lld",&col[i]);
gg+=mark[col[i]]^1;
mark[col[i]]=1;
}
for(ll i=0;i<n-1;i++)
{
ll u,v;
scanf("%lld%lld",&u,&v);
vt[u].push_back(v);
vt[v].push_back(u);
}
dfs(1,-1);
ll ANS=n*(n-1)*gg/2;
for(ll i=1;i<=n;i++)
{
if(gg==col[1]||!mark[i])continue;
ans+=(n-sum[i])*(n-sum[i]-1)/2;
}
printf("Case #%lld: %lld\n",cas++,ANS-ans);
}
}