Address
Solution
首先,一个显然的 DP 状态:
f
[
i
]
f[i]
f[i] 表示第一个数当前为
i
i
i ,将其变成
0
0
0 的期望步数。
边界当然是
f
[
0
]
=
0
f[0]=0
f[0]=0 。
讨论一波转移:
设
P
(
i
,
x
)
P(i,x)
P(i,x) 表示当第一个数为
i
i
i 时,
k
k
k 轮减操作让第一个数减少
x
x
x 的概率。
这样转移就很显然了:
当
i
<
n
i<n
i<n 时:
f
[
i
]
=
1
+
1
m
+
1
∑
j
=
0
i
+
1
P
(
i
+
1
,
j
)
×
f
[
i
−
j
+
1
]
+
m
m
+
1
∑
j
=
0
i
P
(
i
,
j
)
×
f
[
i
−
j
]
f[i]=1+\frac 1{m+1}\sum_{j=0}^{i+1}P(i+1,j)\times f[i-j+1]+\frac m{m+1}\sum_{j=0}^iP(i,j)\times f[i-j]
f[i]=1+m+11j=0∑i+1P(i+1,j)×f[i−j+1]+m+1mj=0∑iP(i,j)×f[i−j]
当
i
=
n
i=n
i=n 时:
f
[
i
]
=
1
+
∑
j
=
0
i
P
(
i
,
j
)
×
f
[
i
−
j
]
f[i]=1+\sum_{j=0}^iP(i,j)\times f[i-j]
f[i]=1+j=0∑iP(i,j)×f[i−j]
要解决两个小问题:
(1)
P
(
i
,
j
)
P(i,j)
P(i,j) 的值。
分下类: 当
j
<
i
j<i
j<i 时,相当于在
k
k
k 次操作中选出
j
j
j 次操作对第一个数进行,剩下的
k
−
j
k-j
k−j 次操作对剩下的
m
m
m 个数进行。
所以:
P
(
i
,
j
)
=
{
C
k
j
×
m
k
−
j
(
m
+
1
)
k
j
<
i
1
−
∑
k
=
0
i
−
1
P
(
i
,
k
)
j
=
i
P(i,j)=\begin{cases}\frac{C_k^j\times m^{k-j}}{(m+1)^k}&j<i\\1-\sum_{k=0}^{i-1}P(i,k)&j=i\end{cases}
P(i,j)={(m+1)kCkj×mk−j1−∑k=0i−1P(i,k)j<ij=i
特别地,如果
k
<
j
k<j
k<j 则
P
(
i
,
j
)
=
0
P(i,j)=0
P(i,j)=0 。
(2) 转移的后效性。
把每个
f
[
i
]
f[i]
f[i] 当作一个未知变量,使用高斯消元解方程。
但这样复杂度是
O
(
T
n
3
)
O(Tn^3)
O(Tn3) 的。
发现系数矩阵长这个样子:
[
X
0
0
0
0
0
0
…
0
X
X
X
0
0
0
0
…
0
X
X
X
X
0
0
0
…
0
X
X
X
X
X
0
0
…
0
X
X
X
X
X
X
0
…
0
X
X
X
X
X
X
X
…
0
⋮
⋮
⋮
⋮
⋮
⋮
⋮
⋱
⋮
X
X
X
X
X
X
X
X
X
X
X
X
X
X
X
X
X
X
]
\begin{bmatrix}X&0&0&0&0&0&0&\dots&0\\X&X&X&0&0&0&0&\dots&0\\X&X&X&X&0&0&0&\dots&0\\X&X&X&X&X&0&0&\dots&0\\X&X&X&X&X&X&0&\dots&0\\X&X&X&X&X&X&X&\dots&0\\\vdots&\vdots&\vdots&\vdots&\vdots&\vdots&\vdots&\ddots&\vdots\\X&X&X&X&X&X&X&X&X\\X&X&X&X&X&X&X&X&X\end{bmatrix}
⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡XXXXXX⋮XX0XXXXX⋮XX0XXXXX⋮XX00XXXX⋮XX000XXX⋮XX0000XX⋮XX00000X⋮XX………………⋱XX000000⋮XX⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤
从第一列到第
n
+
1
n+1
n+1 列分别表示
f
[
0
]
f[0]
f[0] 到
f
[
n
]
f[n]
f[n] ,第一行到第
n
+
1
n+1
n+1 行分别表示
f
[
0
]
f[0]
f[0] 和
f
[
n
]
f[n]
f[n] 的转移。
这矩阵已经非常接近于下三角矩阵。
我们只需要从最后一行开始网上,对于第
i
i
i (
i
>
2
i>2
i>2 )行,只需要用第
i
i
i 行去消第
i
i
i 行使得第
i
i
i 行第
i
+
1
i+1
i+1 列为
0
0
0 即可。
这样系数矩阵就变成了下三角矩阵,从
f
[
0
]
f[0]
f[0] 开始一一代入即可。
注:如果出现了除以
0
0
0 的情况则方程组无解,输出
−
1
-1
−1 。
时间复杂度
O
(
T
n
2
)
O(Tn^2)
O(Tn2) 。
Code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define Rof(i, a, b) for (i = a; i >= b; i--)
inline int read()
{
int res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
return bo ? ~res + 1 : res;
}
template <class T>
T Min(T a, T b) {return a < b ? a : b;}
const int N = 1505, ZZQ = 1e9 + 7;
int n, p, m, k, inv[N], f[N][N], pw[N], C[N], a[N];
int qpow(int a, int b)
{
int res = 1;
while (b)
{
if (b & 1) res = 1ll * res * a % ZZQ;
a = 1ll * a * a % ZZQ;
b >>= 1;
}
return res;
}
void work()
{
int i, j, alls, orz, tmp, rp;
n = read(); p = read(); m = read(); k = read();
orz = qpow(m + 1, ZZQ - 2);
alls = qpow(qpow(m + 1, k), ZZQ - 2);
C[0] = 1;
For (i, 1, n) C[i] = 1ll * C[i - 1] * inv[i] % ZZQ * (k - i + 1) % ZZQ;
For (i, 0, Min(n, k))
pw[i] = 1ll * qpow(m, k - i) * C[i] % ZZQ * alls % ZZQ;
f[0][0] = 1; f[0][n + 1] = 0;
For (i, 1, n) f[0][i] = 0;
For (i, 1, n)
{
For (j, 0, n + 1) f[i][j] = 0;
f[i][n + 1] = 1; f[i][i] = rp = 1;
For (j, 0, i)
{
tmp = j <= k ? 1ll * pw[j] * (i == n ? 0 : orz) % ZZQ : 0;
f[i][i - j + 1] -= tmp; rp -= tmp;
if (f[i][i - j + 1] < 0) f[i][i - j + 1] += ZZQ;
if (rp < 0) rp += ZZQ;
}
if (i < n) f[i][0] -= 1ll * rp * orz % ZZQ;
if (f[i][0] < 0) f[i][0] += ZZQ;
rp = 1;
For (j, 0, i - 1)
{
tmp = j <= k ? 1ll * pw[j]
* (i == n ? 1 : 1ll * m * orz % ZZQ) % ZZQ : 0;
f[i][i - j] -= tmp; rp -= tmp;
if (f[i][i - j] < 0) f[i][i - j] += ZZQ;
if (rp < 0) rp += ZZQ;
}
f[i][0] -= i == n ? rp : 1ll * rp * m % ZZQ * orz % ZZQ;
if (f[i][0] < 0) f[i][0] += ZZQ;
}
Rof (i, n, 2)
{
if (!f[i][i]) return (void) puts("-1");
int tmp = qpow(f[i][i], ZZQ - 2);
For (j, 0, n + 1) f[i][j] = 1ll * f[i][j] * tmp % ZZQ;
tmp = f[i - 1][i];
For (j, 0, n + 1)
{
f[i - 1][j] -= 1ll * f[i][j] * tmp % ZZQ;
if (f[i - 1][j] < 0) f[i - 1][j] += ZZQ;
}
}
if (!f[1][1]) return (void) puts("-1");
tmp = qpow(f[1][1], ZZQ - 2);
For (i, 0, n + 1) f[1][i] = 1ll * f[1][i] * tmp % ZZQ;
a[0] = 0;
For (i, 1, p)
{
a[i] = f[i][n + 1];
For (j, 0, i - 1)
{
a[i] -= 1ll * f[i][j] * a[j] % ZZQ;
if (a[i] < 0) a[i] += ZZQ;
}
}
printf("%d\n", a[p]);
}
int main()
{
int i, T = read();
inv[1] = 1;
For (i, 2, 1500) inv[i] = 1ll * (ZZQ - ZZQ / i) * inv[ZZQ % i] % ZZQ;
while (T--) work();
return 0;
}