[HDU-6832] A Very Easy Graph Problem【贪心】【搜索】
题意
对于给定的无向图,计算:
∑
i
=
1
n
∑
j
=
1
n
d
(
i
,
j
)
×
[
a
i
=
1
∧
a
j
=
0
]
\sum_{i = 1}^n \sum_{j = 1}^n d(i, j) \times [a_i = 1 \land a_j = 0]
i=1∑nj=1∑nd(i,j)×[ai=1∧aj=0]
思路
对于第 i i i 条边,如果 u , v u, v u,v 在连这条边前就已经联通,那么此时 u , v u, v u,v 的距离之和最大是 ∑ j = 1 i − 1 2 j \sum_{j = 1}^{i - 1} 2^j ∑j=1i−12j ,也就是 2 i − 1 2^i - 1 2i−1 ,而这是比第 i i i 条边的权重 2 i 2^i 2i 要小的。所以第 i i i 条边能够被选作最短路径的情况当前仅当 u , v u, v u,v 不连通。于是我们按照上面的方法选边,最终会选出一颗以最短路径组成的树。
接下来考虑每条边的贡献问题,设总的0、1点数量分别为 c n t 0 , c n t 1 cnt_0, cnt_1 cnt0,cnt1 ,对于一个子树,我们选择该子树的根节点 u u u,再选择其下的子树(以根节点的儿子节点 v v v 为根的子树),设这个子子树有 a 0 a_0 a0 个0, a 1 a_1 a1 个1,则从根节点 u u u 向每一个子节点 v v v 拓展出的边,其贡献为 v a l ( u , v ) ∗ ( ( c n t 0 − a 0 ) ∗ a 1 ) + ( c n t 1 − a 1 ) ∗ a 0 ) val(u, v) * ((cnt_0 - a_0) * a_1) + (cnt_1 - a_1) * a_0) val(u,v)∗((cnt0−a0)∗a1)+(cnt1−a1)∗a0) ,其中 v a l ( u , v ) val(u, v) val(u,v) 是这条边的权重。
综上,并查集维护连通状态,再dfs同时计数即可。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e9 + 7;
const int N = 2e5 + 10;
struct edge
{
int v;
ll w;
};
vector<edge> vec[N];
int type[N];
int sz[N][2];
int par[N], rnk[N];
int cnt[N];
ll pow2[N];
void init(int n)
{
cnt[0] = cnt[1] = 0;
for(int i = 0; i <= n; i++)
{
par[i] = i;
rnk[i] = 0;
vec[i].clear();
sz[i][0] = sz[i][1] = 0;
}
}
int find(int x)
{
if(x == par[x])
return x;
return par[x] = find(par[x]);
}
void unite(int x, int y)
{
x = find(x);
y = find(y);
if(x == y)
return ;
if(rnk[x] > rnk[y])
{
par[y] = x;
}
else
{
par[x] = y;
if(rnk[x] == rnk[y])
rnk[y]++;
}
}
ll dfs(int now, int fa)
{
ll ans = 0;
sz[now][0] = 0, sz[now][1] = 0;
sz[now][type[now]] = 1;
for(auto to : vec[now])
{
if(to.v == fa)
continue;
ans = (ans + dfs(to.v, now)) % mod;
int v = to.v;
sz[now][0] += sz[v][0];
sz[now][1] += sz[v][1];
}
for(auto to : vec[now])
{
if(to.v == fa)
continue;
ll num = (sz[to.v][0] * (cnt[1] - sz[to.v][1]) % mod + sz[to.v][1] * (cnt[0] - sz[to.v][0]) % mod) % mod;
ans = (ans + to.w * num % mod) % mod;
}
return ans;
}
int main()
{
pow2[0] = 1;
for(int i = 1; i <= (int)2e5; i++)
pow2[i] = pow2[i - 1] * 2LL % mod;
int T;
cin >> T;
while(T--)
{
int n, m;
cin >> n >> m;
init(n);
for(int i = 1; i <= n; i++)
{
cin >> type[i];
cnt[type[i]]++;
}
for(int i = 1, u, v; i <= m; i++)
{
cin >> u >> v;
ll w = pow2[i];
if(find(u) != find(v))
{
unite(u, v);
vec[u].push_back({v, w});
vec[v].push_back({u, w});
}
}
cout << dfs(1, -1) << endl;
}
return 0;
}