Description
给你一个n*m的方块,每天靠边都有p的概率被吹掉,一共连续吹k天,问最后每一层都有方块且互相连通的概率,以逆元形式输出。
Sample Input
2 2
1 2
1
Sample Output
937500007
首先考虑设
h
[
i
]
[
l
]
[
r
]
h[i][l][r]
h[i][l][r]为第i行l~r的块都在,其他吹掉,前面合法的概率。
因为每一层是单独分开的吗,你可以预处理出一个pl,pr分别表示1 ~ l-1被吹掉,l保留的概率,r+1 ~ m被吹掉,r保留的概率,就可以算出一个区间没被吹的概率。
只要枚举一个范围转移即可,时间复杂度
O
(
n
m
4
)
O(nm^4)
O(nm4)。
考虑引入前缀和优化,设:
f
[
i
]
[
r
]
=
∑
l
=
1
r
h
[
i
]
[
l
]
[
r
]
f[i][r]=\sum_{l=1}^rh[i][l][r]
f[i][r]=∑l=1rh[i][l][r]
S
l
[
i
]
[
r
]
=
∑
l
=
1
r
f
[
i
]
[
l
]
Sl[i][r]=\sum_{l=1}^rf[i][l]
Sl[i][r]=∑l=1rf[i][l]
S
r
Sr
Sr根据相当于
S
l
Sl
Sl反过来。
那么每次转移的时候我们多维护一个这个,考虑用总数减去不合法的,可得DP方程:
h
[
i
]
[
l
]
[
r
]
=
p
l
[
l
]
∗
p
r
[
r
]
∗
(
S
l
[
i
−
1
]
[
m
]
−
S
l
[
i
−
1
]
[
l
−
1
]
−
S
r
[
i
−
1
]
[
r
+
1
]
)
h[i][l][r]=pl[l]*pr[r]*(Sl[i-1][m]-Sl[i-1][l-1]-Sr[i-1][r+1])
h[i][l][r]=pl[l]∗pr[r]∗(Sl[i−1][m]−Sl[i−1][l−1]−Sr[i−1][r+1])
这样就能做到
O
(
n
m
2
)
O(nm^2)
O(nm2)。
考虑其实可以直接维护
f
f
f。
f
[
i
]
[
r
]
=
∑
l
=
1
r
h
[
i
]
[
l
]
[
r
]
f[i][r]=\sum_{l=1}^rh[i][l][r]
f[i][r]=∑l=1rh[i][l][r]
=
∑
l
=
1
r
p
l
[
l
]
∗
p
r
[
r
]
∗
(
S
l
[
i
−
1
]
[
m
]
−
S
l
[
i
−
1
]
[
l
−
1
]
−
S
r
[
i
−
1
]
[
r
+
1
]
)
=\sum_{l=1}^rpl[l]*pr[r]*(Sl[i-1][m]-Sl[i-1][l-1]-Sr[i-1][r+1])
=∑l=1rpl[l]∗pr[r]∗(Sl[i−1][m]−Sl[i−1][l−1]−Sr[i−1][r+1])
=
∑
l
=
1
r
p
l
[
l
]
∗
p
r
[
r
]
∗
S
l
[
i
−
1
]
[
m
]
−
p
l
[
l
]
∗
p
r
[
r
]
∗
S
l
[
i
−
1
]
[
l
−
1
]
−
p
l
[
l
]
∗
p
r
[
r
]
∗
S
r
[
i
−
1
]
[
r
+
1
]
=\sum_{l=1}^rpl[l]*pr[r]*Sl[i-1][m]-pl[l]*pr[r]*Sl[i-1][l-1]-pl[l]*pr[r]*Sr[i-1][r+1]
=∑l=1rpl[l]∗pr[r]∗Sl[i−1][m]−pl[l]∗pr[r]∗Sl[i−1][l−1]−pl[l]∗pr[r]∗Sr[i−1][r+1]
于是你可以处理一个
∑
l
=
1
r
p
l
[
l
]
\sum_{l=1}^rpl[l]
∑l=1rpl[l],
∑
l
=
1
r
p
l
[
l
]
∗
S
[
i
−
1
]
[
l
−
1
]
\sum_{l=1}^rpl[l]*S[i-1][l-1]
∑l=1rpl[l]∗S[i−1][l−1]。
于是转移就可以做到
O
(
n
m
)
O(nm)
O(nm)
由于你前后好像没什么关系,于是我只开了一维。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const LL mod = 1e9 + 7;
int _min(int x, int y) {return x < y ? x : y;}
int _max(int x, int y) {return x > y ? x : y;}
int read() {
int s = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
return s * f;
}
LL jc[110000], inv[110000];
LL f[1600], suml[110000], sumr[110000], gl[1600];
LL pl[1600], pr[1600], S[1600];
LL pow_mod(LL a, LL k) {
LL ans = 1;
while(k) {
if(k & 1) (ans *= a) %= mod;
(a *= a) %= mod; k /= 2;
} return ans;
}
LL C(int n, int m) {return jc[n] * inv[m] % mod * inv[n - m] % mod;}
int main() {
int n = read(), m = read();
int a = read(), b = read();
int k = read();
LL P = (LL)a * pow_mod(b, mod - 2) % mod;
suml[0] = 1LL; for(int i = 1; i <= k; i++) suml[i] = suml[i - 1] * P % mod;
P = (LL)(b - a) * pow_mod(b, mod - 2) % mod;
sumr[0] = 1LL; for(int i = 1; i <= k; i++) sumr[i] = sumr[i - 1] * P % mod;
jc[0] = inv[0] = 1; for(int i = 1; i <= k; i++) jc[i] = (LL)jc[i - 1] * i % mod;
inv[k] = pow_mod(jc[k], mod - 2);
for(int i = k - 1; i >= 1; i--) inv[i] = (LL)inv[i + 1] * (i + 1) % mod;
for(int i = 1; i <= m; i++) {
if(i > k + 1) pl[i] = 0;
else pl[i] = (suml[i - 1] * sumr[k - i + 1] % mod) * C(k, i - 1) % mod;
pr[m - i + 1] = pl[i];
} S[0] = 0; for(int i = 1; i <= m; i++) S[i] = (S[i - 1] + pl[i]) % mod;
for(int i = 1; i <= m; i++) f[i] = S[i] * pr[i] % mod;
suml[0] = sumr[0] = sumr[m + 1] = 0;
for(int i = 1; i <= m; i++) suml[i] = (suml[i - 1] + f[i]) % mod, sumr[m - i + 1] = suml[i];
for(int i = 1; i <= m; i++) gl[i] = (gl[i - 1] + pl[i]) % mod;
S[0] = 0; for(int i = 1; i <= m; i++) S[i] = (S[i - 1] + pl[i] * suml[i - 1] % mod) % mod;
for(int i = 2; i <= n; i++) {
for(int r = 1; r <= m; r++) {
f[r] = (gl[r] * pr[r] % mod) * (suml[m] - sumr[r + 1]) % mod;
(f[r] -= S[r] * pr[r] % mod) %= mod;
(f[r] += mod) %= mod;
} for(int r = 1; r <= m; r++) suml[r] = (suml[r - 1] + f[r]) % mod, sumr[m - r + 1] = suml[r];
for(int r = 1; r <= m; r++) S[r] = (S[r - 1] + pl[r] * suml[r - 1] % mod) % mod;
} printf("%lld\n", suml[m]);
return 0;
}