There is a tree with nn nodes, each of which
has a type of color represented by an integer,
where the color of node ii is cici.
The path between each two different nodes is
unique, of which we define the value as the
number of different colors appearing in it.
Calculate the sum of values of all paths on the
tree that has n(n−1)2n(n−1)2 paths in total.
Input
The input contains multiple test cases.
For each test case, the first line contains one
positive integers nn, indicating the number of
node. (2≤n≤200000)(2≤n≤200000)
Next line contains nn integers where the ii-th
integer represents cici, the color of node ii.
(1≤ci≤n)(1≤ci≤n)
Each of the next n−1n−1 lines contains two
positive integers x,yx,y (1≤x,y≤n,x≠y)
(1≤x,y≤n,x≠y), meaning an edge between
node xx and node yy.
It is guaranteed that these edges form a tree.
Output
For each test case, output ” Case #xx: yy” in
one line (without quotes), where xx indicates
the case number starting from 11 and yy
denotes the answer of corresponding case.
Sample Input
3
1 2 1
1 2
2 3
6
1 2 1 3 2 1
1 2
1 3
2 4
2 5
3 6
Sample Output
Case #1: 6
Case #2: 29
题意:
现在有一棵树,每个节点都是有颜色的,每两个点之间的路径是唯一的。如果这条路径中包含了某个颜色,
那么这条路就算是一条该颜色的路径。
现在给你 n 个点,告诉你每个点的颜色,让你求每个颜色有多少条路径,最后输出总和。
例如上图:
思路:
逆向思维,想求每种颜色的路径,那么就先假设每种颜色的路径为全部路径,然后去除掉不符合条件的路径即可。对于白色来说, 8 - 4 - 9 这三个点形成的路径一定不会有白色路径 ( 即不符合规则的路径 ) 。那么这些连通块怎么求呢。我们可以用size [ 4 ] - ( size [ 11 ] + size [ 12 ] ) 来求出中间连通块的大小。
设置几个变量
size[ ] 数组表示以该点为根节点的树大小;
sum[某颜色] 数组表示遍历到当前节点,以某颜色为根节点形成的树的大小总和;
add 为以该节点颜色的为根节点的一颗子树的大小;
addsum 表示该节点子树中以该节点颜色为根节点的子树大小之和;
PS:该方法有一个bug,因为我们求连通块的时候用的是包围的方法,而图中的 1 点只有下包围圈,上面没有封顶,所以 1 点的连通块大小求不出来,所以我们再进行一步计算特殊处理该区域。用 n - sum[ 非黑 ] 就求出来了。
int vis[maxn];
int n;
LL ans;
int cas=1;
vector<int> e[maxn];
LL path(int a) ///由点求路径数量的函数
{
return (LL)a*(LL)(a-1)/2;
}
void dfs(int u,int fa) ///利用深搜对树进行遍历
{
siz[u]=1;
int addsum=0;
for(auto &v:e[u])
{
if(v==fa)
continue;
int oldsum=sum[c[u]]; ///记录旧数据
dfs(v,u);
int add=sum[c[u]]-oldsum; ///新老数据之差 ( 遍历前后之差 ) 即为以该节点为根的树的大小
ans+=path((LL)(siz[v]-add)); ///ans 用来记录不符合规则的路径数量之和
addsum+=add;
siz[u]+=siz[v]; ///size [ ] 的更新
}
sum[c[u]]+=siz[u]-addsum; ///sum [ ] 的更新
}
void work()
{
ans=0;
memset(sum,0,sizeof(sum));
memset(siz,0,sizeof(siz));
memset(vis,0,sizeof(vis));
for(int i=1;i<=n;i++)
{
e[i].clear();
scanf("%d",&c[i]);
vis[c[i]]=1;
}
int num=0;
for(int i=1;i<=n;i++)
{
num+=(vis[i]>0);
}
int u,v;
for(int i=1;i<n;i++)
{
scanf("%d %d",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
memset(vis,0,sizeof(vis));
dfs(1,-1);
vis[c[1]]=1;
for(int i=2;i<=n;i++) ///单独进行的特殊计算,用来求算法的bug处
{
if(!vis[c[i]])
{
ans+=path((LL)n-sum[c[i]]);
vis[c[i]]=1;
}
}
LL myans=path((LL)n)*(LL)num-ans; ///答案 = 总路径数量 - 不符合规则的路径数量之和
printf("Case #%d: %lld\n",cas++,myans);
}
int main()
{
while(~scanf("%d",&n))
{
work();
}
return 0;
}