思路:用f[state]表示在没有fat brother影响的情况下,maze获胜的概率。 则 f[state] = 平局的概率 * f[state] + sigma(pi * f[子状态]),f[state] = sigma(pi * f[子状态]) / (1 - 平局的概率)
用dp[state] 表示在有fat brother的情况下,maze获胜的概率。 每一步枚举 fat brother的决策,然后取最大值即可。
用到两个辅助数组,mask[state][i]表示state这个集合里的人都出 i 的概率。 msk[state1][state2]表示state1个人的集合中,出拳的种类为state2的概率(包括maze),msk数组的作用是用于计算平局的概率。 这题卡时挺紧,没用上msk优化之前超时了很多次。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
double f[1 << 15],dp[1 << 15];
double a[25][5];
double mask[1 << 15][5];
double msk[1 << 15][8];
int rev[1 << 15];
int lowbit(int x)
{
return x & (-x);
}
void solve()
{
for(int i = 0; i < 15; i ++)
rev[1 << i] = i;
int n;
scanf("%d",&n);
for(int i = 0; i < 3; i ++) scanf("%lf",&a[n][i]);
a[n][3] = a[n][0];
a[n][4] = a[n][1];
for(int i = 0; i < n; i ++) {
for(int j = 0; j < 3; j ++)
scanf("%lf",&a[i][j]);
a[i][3] = a[i][0];
a[i][4] = a[i][1];
}
mask[0][0] = mask[0][1] = mask[0][2] = mask[0][3] = mask[0][4] = 1;
for(int i = 1; i < (1 << n); i ++) {
int x = rev[lowbit(i)];
for(int k = 0; k < 3; k ++)
mask[i][k] = mask[i ^ lowbit(i)][k] * a[x][k];
mask[i][3] = mask[i][0];
mask[i][4] = mask[i][1];
}
for(int i = 0; i < (1 << n); i ++)
for(int j = 0; j < 8; j ++)
msk[i][j] = 0;
for(int i = 0; i < 3; i ++)
msk[0][1 << i] = a[n][i];
for(int i = 1; i < (1 << n); i ++) {
int x = rev[lowbit(i)];
for(int k = 0; k < 8; k ++) {
for(int j = 0; j < 3; j ++) {
msk[i][k | (1 << j)] += msk[i ^ (1 << x)][k] * a[x][j];
}
}
}
f[0] = 1;
for(int i = 1; i < (1 << n); i ++) {
double tmp, p = 0;
double tot = 0;
for(int j = i; j; j = (j - 1) & i) {
tmp = mask[j][0] * a[n][1] * mask[i ^ j][1] + mask[j][1] * a[n][2] * mask[i ^ j][2] + mask[j][2] * a[n][3] * mask[i ^ j][3];
tot += tmp * f[i ^ j];
//p += tmp;
//p += mask[j][1] * a[n][0] * mask[i ^ j][0] + mask[j][2] * a[n][1] * mask[i ^ j][1] + mask[j][3] * a[n][2] * mask[i ^ j][2];
}
p = msk[i][7] + msk[i][1] + msk[i][2] + msk[i][4];
f[i] = tot / (1 - p);
}
dp[0] = max(a[n][1] / (1 - a[n][0]),a[n][2] / (1 - a[n][1]));
dp[0] = max(dp[0],a[n][3] / (1 - a[n][2]));
for(int i = 1; i < (1 << n); i ++) {
double maxn = 0;
for(int j = 1; j <= 3; j ++) {
double p = 0,tot = 0;
for(int k = i; ; k = (k - 1) & i) {
if(k) {
tot += mask[k][j - 1] * a[n][j] * mask[k ^ i][j] * dp[k ^ i];
//p += mask[k][j - 1] * a[n][j] * mask[k ^ i][j];
//p += mask[k][j + 1] * a[n][j] * mask[k ^ i][j];
}
tot += mask[k][j] * a[n][j + 1] * mask[k ^ i][j + 1] * f[k ^ i];
//p += mask[k][j] * a[n][j + 1] * mask[k ^ i][j + 1];
//p += mask[k][j] * a[n][j - 1] * mask[k ^ i][j - 1];
if(!k) break;
}
for(int k = 0; k < 8; k ++) {
int x = k | (1 << (j == 3 ? 0 : j));
if(x == 7 || x == 1 || x == 2 || x == 4) p += msk[i][k];
}
maxn = max(maxn,tot / (1 - p));
}
dp[i] = maxn;
}
printf("%.10lf\n",dp[(1 << n) - 1]);
}
int main()
{
int t;
scanf("%d",&t);
while(t --) {
solve();
}
return 0;
}