题意:
袋子里有 w w w 只白鼠和 b b b 只黑鼠 , A A A 和 B B B 每次轮流抓,谁先抓到白色谁就赢。 B B B 每次随机抓完一只之后会有另一只随机老鼠跑出来。如果两个人都没有抓到白色则 B B B赢。 A A A 先抓,问 A A A 赢的概率。
题目分析:
很显然的一道概率 d p dp dp……
首先定义 d p [ i ] [ j ] {dp[i][j]} dp[i][j] 为 轮到A抓老鼠且剩余 i i i 只白鼠, j j j 只黑鼠时的获胜概率。
那么情况有两种:
-
A A A 在抓到了白鼠(直接获胜)
概率: p 1 = i / ( i + j ) {p1 = i / (i + j)} p1=i/(i+j);
解释: ( i + j ) (i + j) (i+j) 为此时老鼠的总数量
-
A A A在这一局没有抓到白鼠
第一种情况很好转移,所以重点在于第二种情况,思考一下可以发现——为了让 A A A 获胜,这一局 B B B 就不能获胜(抓到黑鼠)。
在第二种情况下, B B B 抓到黑鼠后的情况有两种 :
-
B B B抓到黑鼠,逃走的是白鼠
概率: p 2 = j / ( i + j ) ∗ ( j − 1 ) / ( i + j − 1 ) ∗ i / ( j + i − 2 ) ∗ d p [ i − 1 ] [ j − 2 ] {p2 = j / (i + j) * (j - 1) / (i + j - 1) * i / (j + i - 2) * dp[i - 1][j - 2]} p2=j/(i+j)∗(j−1)/(i+j−1)∗i/(j+i−2)∗dp[i−1][j−2];
解释: j / ( i + j ) j / (i + j) j/(i+j) , A A A抓到黑鼠的概率。
( j − 1 ) / ( i + j − 1 ) (j - 1) / (i + j - 1) (j−1)/(i+j−1) , B B B抓到黑鼠的概率。
i / ( j + i − 2 ) i / (j + i - 2) i/(j+i−2) 逃走的是白鼠的概率。
d p [ i − 1 ] [ j − 2 ] dp[i - 1][j - 2] dp[i−1][j−2] 为 A A A在剩余 i − 1 i - 1 i−1只白鼠, j − 2 j - 2 j−2只黑鼠时 A A A获胜的概率 -
B B B抓到黑鼠,逃走的是黑鼠
概率: p 3 = j / ( i + j ) ∗ ( j − 1 ) / ( i + j − 1 ) ∗ ( j − 2 ) / ( i + j − 2 ) ∗ d p [ i ] [ j − 3 ] {p3 = j / (i + j) * (j - 1) / (i + j - 1) * (j - 2) / (i + j - 2) * dp[i][j - 3]} p3=j/(i+j)∗(j−1)/(i+j−1)∗(j−2)/(i+j−2)∗dp[i][j−3];
解释:参考前文
综上所述:白鼠在这一句获胜的概率为三种情况概率的总和
d p [ i ] [ j ] = p 1 + p 2 + p 3 {dp[i][j] = p1 + p2 + p3} dp[i][j]=p1+p2+p3
所以状态转移方程就是:
dp[i][j] += (double) i / (i + j);
if(j >= 2) dp[i][j] += (double) j / (i + j) * (j - 1) / (i + j - 1) * i / (j + i - 2) * dp[i - 1][j - 2];
if(j >= 3) dp[i][j] += (double) j / (i + j) * (j - 1) / (i + j - 1) * (j - 2) / (i + j - 2) * dp[i][j - 3];
完整代码:
#include<cstdio>
#include<algorithm>
using namespace std;
const int MAXN = 1005;
int w,b;
double dp[MAXN][MAXN];//轮到A抓老鼠且剩余i只白鼠,j只黑鼠时的获胜概率
int main() {
scanf("%d %d",&w,&b);
for(int i = 1; i <= w; i ++) dp[i][0] = 1; //当全是白鼠时,A一定获胜,概率为 1
for(int i = 0; i <= b; i ++) dp[0][i] = 0; //当没有白鼠时,B获胜,A获胜概率为 0
for(int i = 1; i <= w; i ++) {
for(int j = 1; j <= b; j ++) {
dp[i][j] += (double) i / (i + j);
if(j >= 2) dp[i][j] += (double) j / (i + j) * (j - 1) / (i + j - 1) * i / (j + i - 2) * dp[i - 1][j - 2];
if(j >= 3) dp[i][j] += (double) j / (i + j) * (j - 1) / (i + j - 1) * (j - 2) / (i + j - 2) * dp[i][j - 3];
}
}
printf("%.9lf",dp[w][b]);
return 0;
}