传送门:ZOJ4098
题意:给出n个数,有m个互斥关系,可以获得所有的排列组合关系,并且在这些组合中不能同时存在互斥的一对数,每一种可能的组合有一个value=排列中所有数的乘积,题目要求我们输出这些值的方差。拿样例来说,给出三个数1、2、3,一组互斥关系1、2表示第一个和第二个数在一个组合中不能同时存在,因此我们可以获得的所有组合关系有{}、{1}、{2}、{3}、{1 3}、{2 3},因为{1 2}、{1 2 3}中存在互斥的一对数,因此不能作为value。因此我们总共可以得到六个value{1 1 2 3 3 6}(空集value为1)按照题目给出的方差公式即可求出答案。
解析:
观察数据范围发现,n为40且m为0时,所有排列的数量已经达到了2^n个,因此无法通过求出每个值的方法来计算答案。因此我们需要对公式进行化简。我们对方差公式进行化简发现:
因此问题可以转化为求出所有value的和以及平方和,考虑到n最大为40,因此我们采用折半搜索求解。
下面解释一下为什么折半搜索可以得到正确答案:
我们将所有数尽可能均匀地分为两堆数,在两堆数中不存在跨堆的互斥关系(比赛没考虑到这一点),折半搜索记录的是所有存在的组合关系的和sum1,sum2以及平方和summ1,summ2,因此我们要求的总和即为sum1+sum2+sum1sum2,总平方和为summ1+summ2+summ1summ2,为什么是这个数呢?因为我们在折半搜索过程中得到的所有排列关系是所求排列关系的子关系,对结果相乘即为我们要求的总关系。下面拿样例来具体说明,我们可以将所有的数分为{1 2}和{3}两堆,对这两堆求得的排列关系为{1} {2}以及第二堆的{3},我们将两个集合相乘可以得到{1 3} {2 3},加上原来的子排列可以得到{1}{2}{3}{1 3}{2 3},加上一个空集即我们所要求的答案。为什么不能跨堆存在互斥关系呢? 这一点经过上面的分析已经很明显了,两个集合相乘的结果就是集合元素的合并,因此如果跨堆存在互斥关系,合并后的集合也有可能存在互斥关系,因此无法得到正确答案。
在得到求解答案的方法后,我们需要思考如何才能让两个堆中的数尽可能平均,从而提高效率呢?我们考虑对数进行建图,发现一个联通块内的数是有可能存在互斥关系的,因此我们可以记录所有连通块内的元素个数,将每个连通块内的元素逐一放入我们要进行搜索的堆中,尽可能的平分所有数,这里可以考虑DP或者贪心。
最后要提醒的是,所有的结果因为存在空集都要+1
附上AC代码(我觉得除了开扑腾没人想看吧(:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 1e9 + 7;
const int maxn = 1e6;
ll n, m;
ll a[55];
int mp[55][55];
vector<int> v, vv, w[55];
int cc;
int tu[55][55];
int va[55], vaa[55];
int tot, tott;
bool vis[55];
bool viss[55];
ll sum, summ;
ll su, suu;
ll cnt, cntt;
int po[55];
int to;
ll qpow(ll a, ll b, ll p) {
ll ret = 1;
while (b) {
if (b & 1) ret = (ret * a) % p;
a = (a * a) % p;
b >>= 1;
}
return ret;
}
ll ny(ll a) {
return qpow(a, mod - 2, mod);
}
void dfs(int in) {
if (in == v.size()) {
if (tot == 0) return;
ll tmp = 1;
ll tmpp = 1;
for (int i = 0; i < tot; i++) {
tmp *= (a[va[i]] * a[va[i]]) % mod;
tmp %= mod;
tmpp *= a[va[i]];
tmpp %= mod;
}
sum += tmp;
sum %= mod;
su += tmpp;
su %= mod;
cnt++;
return;
}
for (int i = 0; i < tot; i++) {
if (mp[va[i]][v[in]] == 1) {
dfs(in + 1);
return;
}
}
vis[v[in]] = 1;
va[tot++] = v[in];
dfs(in + 1);
vis[v[in]] = 0;
tot--;
dfs(in + 1);
}
void dfss(int in) {
if (in == vv.size()) {
if (tott == 0) return;
ll tmp = 1;
ll tmpp = 1;
for (int i = 0; i < tott; i++) {
tmp *= (a[vaa[i]] * a[vaa[i]]) % mod;
tmp %= mod;
tmpp *= a[vaa[i]];
tmpp %= mod;
}
summ += tmp;
summ %= mod;
suu += tmpp;
suu %= mod;
cntt++;
return;
}
for (int i = 0; i < tott; i++) {
if (mp[vaa[i]][vv[in]] == 1) {
dfss(in + 1);
return;
}
}
vis[vv[in]] = 1;
vaa[tott++] = vv[in];
dfss(in + 1);
vis[vv[in]] = 0;
tott--;
dfss(in + 1);
}
void df(int x) {
vis[x] = 1;
po[to++] = x;
for (int i = 1; i <= n; i++) {
if (vis[i] == 0 && tu[x][i] == 1) df(i);
}
}
int cmp(vector<int> a, vector<int> b) {
return a.size() < b.size();
}
int main() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
a[i] %= mod;
}
while (m--) {
int x, y;
scanf("%d%d", &x, &y);
mp[x][y] = 1;
mp[y][x] = 1;
tu[x][y] = 1;
tu[y][x] = 1;
}
for (int i = 0; i < 55; i++) vis[i] = 0;
for (int i = 1; i <= n; i++) {
if (vis[i] == 0) {
to = 0;
df(i);
for (int j = 0; j < to; j++) {
w[cc].push_back(po[j]);
}
cc++;
}
}
sort(w, w + cc, cmp);
for (int i = cc - 1; i >= 0; i--) {
if (v.size() + w[i].size() > (n + 1) / 2) break;
viss[i] = 1;
for (int j = 0; j < w[i].size(); j++) v.push_back(w[i][j]);
}
for (int i = 0; i < cc; i++) {
if (v.size() + w[i].size() > (n + 1) / 2) break;
viss[i] = 1;
for (int j = 0; j < w[i].size(); j++) v.push_back(w[i][j]);
}
for (int i = 0; i < cc; i++) {
if (viss[i] == 0) {
for (int j = 0; j < w[i].size(); j++) vv.push_back(w[i][j]);
}
}
for (int i = 0; i < 55; i++) vis[i] = 0;
dfs(0);
for (int i = 0; i < 55; i++) vis[i] = 0;
dfss(0);
ll k = ny(((cnt + cntt) % mod + (cnt * cntt) % mod + 1) % mod) % mod;
ll ba = ((su + suu) % mod + (su * suu) % mod + 1) % mod;
ba *= k;
ba %= mod;
ll ans = ((sum + summ) % mod + (sum * summ) % mod + 1) % mod;
ans *= k;
ans %= mod;
ans -= (ba * ba) % mod;
while (ans < 0) ans += mod;
printf("%lld\n", ans);
return 0;
}