题目大意:
- n n n 个猎人编号为 1 , 2 , ⋯   , n 1, 2, \cdots, n 1,2,⋯,n,依次按逆时针方向排成一个环。
- 第一枪由你打响,你会向第 ( k − 1 )   m o d   n + 1 ( k > 0 ) (k - 1) \bmod n + 1 (k > 0) (k−1)modn+1(k>0) 号猎人开枪,这个被击中的猎人有 1 2 \frac 1 2 21 的概率会死亡。所有被击中的猎人(无论死活),都会继续向他的逆时针方向开始的第 k k k 个(从他自己的下一个开始数)活着的猎人开枪。
- 当只剩一个人时游戏停止,无论最后射向他的子弹是否会打死他。
- 作为编号为 1 1 1 的猎人,想知道自己活到最后的概率。
算法分析:
- 我们很容易想到从终止状态倒推。
- 为了方便我们从 0 0 0 开始编号。
- 可以设计状态 f ( i , j ) f(i,j) f(i,j) 表示还剩 i i i 个人,其中 0 0 0 号在其中,并且编号为 0 0 0,上一枪射中了 j j j, 0 0 0 号能存活到最后的概率。
- 答案即为 f ( n , ( k − 1 )   m o d   n ) f(n,(k-1)\bmod n) f(n,(k−1)modn)。
- 不难根据被射中的这个人是否被打死了,然后转移:
f ( i , j ) = { 1 2 ( f ( i , ( i + k )   m o d   i ) + f ( i − 1 , ( i + k − 1 )   m o d   i ) ) ( i ≠ 0 ) 1 2 f ( i , ( i + k )   m o d   i ) ( i = 0 ) f(i,j)= \begin{cases} \frac{1}{2} \Big(f\big(i,(i+k)\bmod i\big)+f\big(i-1,(i+k-1)\bmod i\big)\Big) & (i\neq 0)\\ \frac{1}{2} ~~f\big(i,(i+k)\bmod i\big) & (i=0) \end{cases} f(i,j)={21(f(i,(i+k)modi)+f(i−1,(i+k−1)modi))21 f(i,(i+k)modi)(i̸=0)(i=0)
- 这里的转移形成了分层图, i − 1 i-1 i−1 到 i i i 可以直接转移,但是 i i i 层之内的转移形成了若干个环,如果直接用高斯消元,时间复杂度为 O ( n 4 ) O(n^4) O(n4)。
- 但是我们发现对于环的形式的转移,可以 O ( n ) \mathcal O(n) O(n) 实现。
- 对于一个环,我们考虑将所有的数表示为 k i x + b i k_ix+b_i kix+bi 的形式,其中 x x x 是环中的指定一个值,然后 k i , b i k_i,b_i ki,bi 都是常数,这样我们可以利用这些表达式,把 x x x 也表达成 k ′ x + b ′ k'x+b' k′x+b′,解方程 x = k ′ x + b ′ x=k'x+b' x=k′x+b′ 即可。
- 所以时间复杂度优化到 O ( n 2 ) \mathcal O(n^2) O(n2)。
#include <bits/stdc++.h>
const int mod = 1e9 + 7;
const int MaxN = 2e3 + 5;
const int inv2 = mod + 1 >> 1;
int n, K;
bool vis[MaxN];
int a[MaxN], b[MaxN], t[MaxN];
int f[MaxN][MaxN], nxt[MaxN];
inline int qpow(int x, int y)
{
int res = 1;
for (; y; y >>= 1, x = 1LL * x * x % mod)
if (y & 1) res = 1LL * res * x % mod;
return res;
}
int main()
{
freopen("hunter.in", "r", stdin);
freopen("hunter.out", "w", stdout);
std::cin >> n >> K;
f[1][0] = 1;
for (int i = 2; i <= n; ++i)
{
for (int j = 0; j < i; ++j)
{
vis[j] = false;
nxt[(j + K) % i] = j;
t[j] = j == 0 ? 0 : 1LL * inv2 * f[i - 1][(j + K - 1) % (i - 1)] % mod;
}
for (int st = 0; st < i; ++st)
{
if (vis[st]) continue;
int u = st;
a[st] = 1, b[st] = 0;
std::vector<int> S(0);
for (; !vis[u]; u = nxt[u])
{
vis[u] = true;
int v = nxt[u];
if (v != st) S.push_back(v);
a[v] = 1LL * a[u] * inv2 % mod;
b[v] = (1LL * b[u] * inv2 + t[v]) % mod;
}
f[i][st] = 1LL * (mod - b[st]) * qpow(a[st] - 1, mod - 2) % mod;
for (auto v : S) f[i][v] = (1LL * a[v] * f[i][st] + b[v]) % mod;
}
}
printf("%d\n", f[n][(K - 1) % n]);
return 0;
}