脑子笨,看人家博客半天才理解,太强了orz
题意:一棵树上,每个节点有一个颜色,任意两个节点之间的距离为它们连线上各节点的颜色数目,问所有路径的和是多少
诈一看以为是分治,其实是树形dp....
思路:对某个颜色x,算出其在树上没有贡献的某一块部分,假设这一块的节点数为y,那么颜色x在这一块上对总答案“少贡献了”y*(y-1)/2 。如果每个颜色对所有路径都有贡献,那么总答案为 n*(n-1)/2*ant(颜色数目) ,减去少贡献的部分就行了
代码:
#include<iostream>
using namespace std;
#include <string.h>
#include <algorithm>
#include <vector>
#define ll long long
#define M 200005
ll sum[M],color[M],sz[M],vis[M],ans;
vector<ll> edg[M];
ll dfs(ll u,ll fa)
{
sz[u]=1;
ll i , ls=0 , end=edg[u].size();
for(i=0;i<end;i++)
{
ll v=edg[u][i];
if(v==fa)
continue;
ll s=sum[color[u]];
sz[u]+=dfs(v,u);
ll lass=sum[color[u]]-s;
ans += (sz[v]-lass)*(sz[v]-lass-1)/2;
ls +=sz[v]-lass;
}
sum[color[u]] += ls+1;
return sz[u];
}
int main()
{
ll n,i,j,k,cnt,t=1;
ll a1=1,a2=2;
while(cin>>n)
{
cnt=ans=0;
memset(sum,0,sizeof(sum));
memset(vis,0,sizeof(vis));
for(i=1;i<=n;i++)
{
// cin>>color[i];
scanf("%I64d",&color[i]);
if(!vis[color[i]])
cnt++;
vis[color[i]]=1;
edg[i].clear();
}
for(i=1;i<n;i++)
{
// cin>>j>>k;
scanf("%I64d %I64d",&j,&k);
edg[j].push_back(k);
edg[k].push_back(j);
}
printf("Case #%I64d: ",t++);
if(cnt==1)
// cout<<"Case #"<<t++<<": "<< <<endl;
printf("%I64d\n" , (n-a1)*n/a2);
else
{
dfs(1,-1);
for(i=1;i<=n;i++)
{
if(!vis[i]) continue;
ans+= (n-sum[i])*(n-sum[i]-a1)/a2;
}
//cout<<"Case #"<<t++<<": "<< <<endl;
printf("%I64d\n" , n*(n-a1)*cnt/a2 - ans);
}
}
return 0;
}