题目链接:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemId=5987
题意:一共有n个城堡,如果给一个城堡i装备武器,那么其攻击值为w[i],否则为1,所有城堡攻击值总量为每一个城堡攻击值的乘积,现在给m个互斥条件,每个互斥条件表示x和y两个城堡不可同时装备武器。假设一共有k种装备方案,装备方案i的所有城堡攻击值为xi,则定义方案的平均值 和方差
为
=
,
=
题目要求方差的值,以A/B的逆元形式输出
思路:首先化简方差公式为
=
设Sk = x1+x2+x3+...+xk, Tk =
在暴力dfs的基础上,考虑现在为第i个城堡,要不要装备:
1.若在i+1~n之间有和第i个城堡互斥的城堡存在,那么就直接暴力第i个城堡装备和不装备两种情况,如果装备,就将后面所有和i互斥的城堡vis[i]++,使得后面不会遍历到与其相斥的城堡
2.但很显然如果全都暴力,会t,所以考虑如果i+1~n之间没有城堡会和第i个互斥,那么即第i个取和不取不影响后面的情况计算,那么Si = Si+Si*w[i], Ti = Ti+Ti*w[i]*w[i],然后情况数量乘2就可以继续遍历下一项了,这样就是O(n)做法
所以可以看到,如果没有互斥限制,那么就是O(n^2)的复杂度,如果有限制条件,最大的情况就是20个不同的互斥的情况,那么复杂度是O(*n)
代码:
#include<iostream>
#include <cstdio>
#include<cstring>
#include<algorithm>
#include<map>
#include<functional>
#include <unordered_map>
#include<queue>
#include<cmath>
#include<unordered_map>
using namespace std;
typedef long long ll;
const ll mod = 1e9 + 7;
int vis[50];
ll extendGcd(ll a, ll b, ll &x, ll &y) {
ll ans, t;
if (b == 0) {
x = 1; y = 0;
return a;
}
ans = extendGcd(b, a%b, x, y);
t = x; x = y; y = t - (a / b)*y;
return ans;
}
ll inv(ll a, ll m) {
ll x, y, d;
d = extendGcd(a, m, x, y);
if (d == 1)
return (x%m + m) % m;
else
return -1;
}
int mm[50][50];
ll w[50];
ll res = 1, sum = 1;
ll ans = 0, k = 0;
int n, m;
void dfs(int id, ll s, ll t, ll kk)
{
if (id > n)
{
res += t;
res %= mod;
sum += s;
sum %= mod;
k += kk;
return;
}
for (int i = id + 1; i <= n + 1; i++)
{
if (vis[i])continue;
vis[i] = 1;
int f = 1;
for (int j = i + 1; j <= n; j++)
{
if (mm[i][j])
{
vis[j]++;
f = 0;
}
}
if (f && i<=n)
{
dfs(i, (s+s*w[i]%mod)%mod, (t+t*w[i]%mod*w[i]%mod)%mod, kk*2);
vis[i] = 0;
break;
}
if (i <= n)
{
dfs(i, s*w[i] % mod, t*w[i] % mod*w[i] % mod, kk);
for (int j = i + 1; j <= n; j++)
{
if (mm[i][j])
vis[j]--;
}
}
dfs(i, s, t, kk);
vis[i] = 0;
break;
}
}
int main()
{
scanf("%d%d", &n, &m);
memset(mm, 0, sizeof(mm));
for (int i = 1; i <= n; i++)
{
scanf("%lld", &w[i]);
w[i] %= mod;
}
for (int i = 0; i < m; i++)
{
int x, y;
scanf("%d%d", &x, &y);
mm[x][y] = mm[y][x] = 1;
}
for (int i = 1; i <= n; i++)
{
memset(vis, 0, sizeof(vis));
vis[i] = 1;
for (int j = i + 1; j <= n; j++)
{
if (mm[i][j])
{
vis[j] = 1;
}
}
dfs(i, w[i], w[i]*w[i]%mod, 1);
}
k++;
ans = (res*inv(k, mod) % mod - sum * sum%mod*inv(k, mod) % mod*inv(k, mod) % mod + mod) % mod;
printf("%lld\n", ans);
}