1006题目链接
第i条边的边权是2^i,由2 ^1+2 ^2+2 ^3+…+2 ^(i-1)<2 ^i,故可以跑最小生成树,两两之间的距离便是最短路了。
计算方面,对于每条边,计算有几个0 1路径经过该边,边权乘以这个个数就可以了。
代码:
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <queue>
#include <vector>
#include <map>
using namespace std;
#define int128 _int128;
#define reg register;
typedef long long ll;
typedef double db;
const int mod=1e9+7;
const int maxn=1e5+5;
const double eps=0.00000001;
struct Edge{
int u,w;
}e[maxn];
vector<Edge>G[maxn];
int a[maxn],pa[maxn],vis[maxn],s1,s2,c1[maxn],c0[maxn];
ll ans=0;
int t,n,m;
int find(int x){
if(x==pa[x]) return x;
return pa[x]=find(pa[x]);
}
void unite(int x,int y){
x=find(x),y=find(y);
if(x==y) return;
pa[x]=y;
}
ll pow2[2*maxn];
void dfs(int x){
vis[x]=1;
for(int i=0;i<G[x].size();i++){
int t=G[x][i].u;
if(vis[t]==0){
dfs(t);
c0[x]+=c0[t];
c1[x]+=c1[t];//计算该点的子树中有多少个0和1的点。
}
}
if(a[x]==0) c0[x]++;
else c1[x]++;
}
void Dfs(int x){
vis[x]=1;
for(register int i=0;i<G[x].size();i++){
int t=G[x][i].u;
if(vis[t]==0){
Dfs(t);
if(a[t]==0&&a[x]==1) ans=(ans+pow2[G[x][i].w]*c0[t]%mod*(s1-c1[t])%mod+pow2[G[x][i].w]*c1[t]%mod*(s2-c0[t])%mod)%mod;
else if(a[t]==0&&a[x]==0) ans=(ans+pow2[G[x][i].w]*c0[t]%mod*(s1-c1[t])%mod+pow2[G[x][i].w]*c1[t]%mod*(s2-c0[t])%mod)%mod;
else if(a[t]==1&&a[x]==0) ans=(ans+pow2[G[x][i].w]*c1[t]%mod*(s2-c0[t])%mod+pow2[G[x][i].w]*c0[t]%mod*(s1-c1[t])%mod)%mod;
else ans=(ans+pow2[G[x][i].w]*c0[t]%mod*(s1-c1[t])%mod+pow2[G[x][i].w]*c1[t]%mod*(s2-c0[t])%mod)%mod;
}
}
}
int main(){
cin>>t;
pow2[0]=1;
for(int i=1;i<=200000;i++){
pow2[i]=pow2[i-1]*2%mod;
}
while(t--){
memset(c0,0,sizeof(c0));
memset(c1,0,sizeof(c1));
s1=s2=0;
scanf("%d%d",&n,&m);
for(register int i=1;i<=n;i++){
pa[i]=i,vis[i]=0;
}
int u,v,p=0;
for(register int i=1;i<=n;++i){
scanf("%d",&a[i]);
if(a[i]==1) s1++;
else s2++;
}
int xx,yy;
for(register int i=1;i<=m;++i){
scanf("%d%d",&xx,&yy);
int u=find(xx),v=find(yy);
if(u!=v){
unite(xx,yy);
Edge x;
x.u=xx,x.w=i;
G[yy].push_back(x);
x.u=yy;
G[xx].push_back(x);
}
}
int rt;
for(register int i=1;i<=n;i++){
if(G[i].size()>0){
dfs(i);
rt=i;
break;
}
}
memset(vis,0,sizeof(vis));
ans=0;
Dfs(rt);
printf("%lld\n",ans);
for(int i=1;i<=n;i++) G[i].clear();
}
}