题目链接
思路:
首先求出所有的2和1的个数并求出他们一共可以结合成多少对,然后再输入关系对,用并查集维护关系,再一点一点减去相关联后减少的对数即可。
代码:
#include<bits/stdc++.h>
#define int long long
#define IOS ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
const int N=1e5+5;
const int M=2e4+5;
const double eps=1e-8;
const int mod=1e9+7;
const int inf=0x7fffffff;
const double pi=3.1415926;
using namespace std;
int w[N],cnt[3],p[N],p1[N],p2[N];
int Find(int x)
{
if(p[x]!=x)
{
p[x]=Find(p[x]);
}
return p[x];
}
int C_2(int n)
{
if(n<2)
{
return 0;
}
return (n-1)*n/2%mod;
}
int C_3(int n)
{
if(n<3)
{
return 0;
}
return n*(n-1)*(n-2)/6%mod;
}
signed main()
{
IOS;
int t;
cin>>t;
while(t--)
{
int n;
cin>>n;
memset(cnt,0,sizeof cnt);
for(int i=1;i<=n;i++)
{
cin>>w[i];
p[i]=i;
if(w[i]==1)
{
p1[i]=1;
p2[i]=0;
cnt[1]++;
}
else
{
p2[i]=1;
p1[i]=0;
cnt[2]++;
}
}
int ans=(C_3(cnt[2])+C_2(cnt[2])*cnt[1]%mod)%mod;
cout<<ans<<endl;
for(int i=0;i<n-1;i++)
{
int k=0,u,v;
cin>>u>>v;
int pu=Find(u), pv=Find(v);
k=(k+p1[pu]*p2[pv]*(cnt[2]-p2[pu]-p2[pv]))%mod;
k=(k+p2[pu]*p1[pv]*(cnt[2]-p2[pu]-p2[pv]))%mod;
k=(k+p2[pu]*p2[pv]*(cnt[2]-p2[pu]-p2[pv]))%mod;
k=(k+p2[pu]*p2[pv]*(cnt[1]-p1[pu]-p1[pv]))%mod;
ans=(ans-k+mod)%mod;
cout<<ans<<endl;
p[pv]=pu;
p1[pu]+=p1[pv];
p2[pu]+=p2[pv];
p1[pv]=0,p2[pv]=0;
}
}
return 0;
}