Description
给出一棵树,每个点有固定的颜色,问树上所有简单路径上不同颜色数之和
Input
多组用例,每组用例首先输入一整数n表示树上节点个数,之后输入n个整数c[i]表示第i个节点的颜色,之后n-1行每行输入两个整数u,v表示一条树边,以文件尾结束输入(2<=n<=2e5,1<=c[i]<=n)
Output
对于每组用例,输出树上所有简单路径上不同颜色数之和
Sample Input
3
1 2 1
1 2
2 3
6
1 2 1 3 2 1
1 2
1 3
2 4
2 5
3 6
Sample Output
Case #1: 6
Case #2: 29
Solution
对于一种颜色i,假设所有简单路径都会经过该种颜色的点,然后去找不合法的路径,把颜色为i的点去掉树就变成若干连通块,每个连通块内任意两点之间的简单路径都不会经过i颜色的点,且任意两个连通块中的点之间的路径必然经过i颜色的点,所以问题在于如何统计每个连通块中的点数,考虑一个i颜色的点x的儿子y所在连通块中点数,那么num等于以y为根的子树中的点数减去y的子树中那些在以i颜色为根的子树中的点
第一种方法是用dfs序快速找不合法点:
为了快速找到这些不合法点我们首先求出dfs序(对于一个点a,求出L[a]表示a的dfs序,R[a]表示以a为根的子树中dfs序最大值,那么[L[a],R[a]]表示以a为根的子树中所有点的dfs序构成的区间),然后用dfs序去快速找点,例如要找y的子树中第一个i颜色的点,我们在所有i颜色点中找到第一个dfs序大于y的dfs序的点z1,那么以z1为根的子树中所有点不在y所在连通块中,之后找下一个不在z1的子树中但dfs序最小的i颜色点,所以找dfs序大于R[z1]的点z2即可,以此类推一直找到超出以y为根所在子树为止,减掉这些不合法点即得到y所在连通块点数num,从答案中减掉num*(num-1)/2即可,注意到可能根节点1所在连通块不会被统计到,所以添加一个超级根0,在统计i颜色的答案时赋予0节点i颜色即可,时间复杂度O(nlogn)
第二种方法是用虚树的思想,在统计一种颜色点的答案时只需要维护该种颜色点的信息然后更新答案,dfs到一个点就只更新及维护该种颜色点的信息,用sum[i]表示被i颜色节点支配的节点个数,和第一种方法类似,例如当前dfs到u点,我们想要得到u节点的儿子节点v所在连通块中的点数num,其实只要记录当前被color[u]支配的节点个数s1以及搜完v的子树后被color[u]支配的节点个数s2,s2-s1即为以v为根的子树中被color[u]支配的节点个数,用size[v]减去s2-s1即为num,从答案中减去num*(num-1)/2即可,注意在搜完以u为根的子树后还要更新sum[color[u]],需要加上被u节点支配的节点个数,这个值很好求,只要在搜u的儿子v时统计一下所有以v为根的子树中被color[u]支配的节点个数(即所有s2-s1的和),然后用size[u]减去这些点即为只被u支配的点的个数,时间复杂度O(n)
Code1
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
typedef long long ll;
#define INF 0x3f3f3f3f
#define maxn 222222
int res=1,n,index,L[maxn],R[maxn],S[maxn],F[maxn];
vector<int>c[maxn],e[maxn];
void dfs(int u,int fa)
{
L[u]=++index;
S[u]=1,F[u]=fa;
for(int v:e[u])
{
if(v==F[u])continue;
F[v]=u;
dfs(v,u);
S[u]+=S[v];
}
R[u]=index;
}
int cmp(int x,int y)
{
return L[x]<L[y];
}
int main()
{
while(~scanf("%d",&n))
{
for(int i=0;i<=n;i++)c[i].clear(),e[i].clear();
for(int i=1;i<=n;i++)
{
int temp;
scanf("%d",&temp);
c[temp].push_back(i);
}
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
e[u].push_back(v),e[v].push_back(u);
}
e[0].push_back(1);
index=0;
dfs(0,0);
ll ans=1ll*n*n*(n-1)/2;
for(int i=1;i<=n;i++)
{
if(c[i].size()==0)
{
ans-=1ll*n*(n-1)/2;
continue;
}
c[i].push_back(0);
sort(c[i].begin(),c[i].end(),cmp);
for(int x:c[i])
for(int y:e[x])
{
if(y==F[x])continue;
int num=S[y],k=L[y];
while(1)
{
L[n+1]=k;
auto it=lower_bound(c[i].begin(),c[i].end(),n+1,cmp);
if(it==c[i].end()||L[*it]>R[y])break;
num-=S[*it],k=R[*it]+1;
}
ans-=1ll*num*(num-1)/2;
}
}
printf("Case #%d: %I64d\n",res++,ans);
}
return 0;
}
Code2
#include<cstdio>
#include<cstring>
#include<vector>
using namespace std;
typedef long long ll;
const int maxn=200001;
int n,C[maxn],Size[maxn],Sum[maxn],Vis[maxn];
ll ans;
struct node
{
int to,next;
}e[2*maxn];
int head[maxn],tot;
void init()
{
memset(head,-1,sizeof(head));
tot=0;
}
void add(int u,int v)
{
e[tot].to=v,e[tot].next=head[u],head[u]=tot++;
}
void dfs(int u,int fa)
{
Size[u]=1;
int s=Sum[C[u]],cnt=0;
for(int i=head[u];~i;i=e[i].next)
{
int v=e[i].to;
if(v==fa)continue;
dfs(v,u);
Size[u]+=Size[v];
int res=Size[v]-(Sum[C[u]]-s);
cnt+=Sum[C[u]]-s;
s=Sum[C[u]];
ans-=(ll)res*(res-1)/2;
}
Sum[C[u]]+=Size[u]-cnt;
}
int main()
{
int Case=1;
while(~scanf("%d",&n))
{
memset(Vis,0,sizeof(Vis));
memset(Sum,0,sizeof(Sum));
for(int i=1;i<=n;i++)scanf("%d",&C[i]),Vis[C[i]]=1;
init();
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
int num=0;//不同颜色节点个数
for(int i=1;i<=n;i++)num+=Vis[i];
ans=(ll)num*n*(n-1)/2;
dfs(1,0);
for(int i=1;i<=n;i++)
if(i!=C[1]&&Vis[i])
{
int res=n-Sum[i];
ans-=(ll)res*(res-1)/2;
}
printf("Case #%d: %I64d\n",Case++,ans);
}
return 0;
}