题目
思路
将水域从下往上将行编号、从左往右将列编号。
怎样处理恰好为 k k k
其实很简单:容斥。说白了就是 ≤ k \le k ≤k的情况,减去 < k <k <k的情况。于是下面都只处理至多为 k k k的情况。
怎样处理至多为 k k k
也很神奇。只需要用 f ( x , j ) f(x,j) f(x,j)表示 宽度为 x x x、最下面的一个危险区域高度不小于 j j j,该 x × 1001 x\times1001 x×1001的矩形满足条件,其概率是多少。当然,我们要忽略海岸线(最下面一行)有危险水域的情况——否则状态数太多。
此时就有了一个范围: 0 ≤ x ≤ k , 2 ≤ j 0\le x\le k,2\le j 0≤x≤k,2≤j
为什么只限定宽度呢?因为每一个水域都是一样的。繁多是个谎言
至于最下面的一个危险区域的高度有什么用呢?答案:用于转移。
不妨设第 t t t列、第 j j j行是一个危险区域。那么,倘若矩形包含第 t t t列,一定是 1 1 1到 j − 1 j-1 j−1行、 1 1 1到 x x x列组成的矩形最大;否则,我们就成功的转化成了独立的子问题:两边都符合条件。
第 j j j行中的危险区域可能有很多——我们应当枚举 最左边的那一个。(当然,最右边的那一个也可以)
这样一来,方案就一定是不重复、不遗漏的。
在本次转移中,我们应当计算第 t t t列的概率——它没有被蓝色的方框(别的 f f f计算)。
所以状态转移就是 g ( i , j ) − g ( i , j + 1 ) = q j − 1 ( 1 − q ) ∑ t = 1 i g ( t − 1 , j + 1 ) g ( i − t , j ) g(i,j)-g(i,j+1)=q^{j-1}(1-q)\sum_{t=1}^{i}g(t-1,j+1)g(i-t,j) g(i,j)−g(i,j+1)=qj−1(1−q)t=1∑ig(t−1,j+1)g(i−t,j)
注意到有用的 g ( i , j ) g(i,j) g(i,j)满足 i ( j − 1 ) ≤ k i(j-1)\le k i(j−1)≤k,也就是 O ( k ln k ) \mathcal O(k\ln k) O(klnk)个。时间复杂度 O ( k 2 ln k ) \mathcal O(k^2\ln k) O(k2lnk).
怎样处理海岸线
不妨用 A ( x ) A(x) A(x)表示,处理了 x x x列,第 x x x列是危险水域。(只考虑最下面一行)
那么答案就是 A ( n + 1 ) ( 1 − q ) \frac{A(n+1)}{(1-q)} (1−q)A(n+1),这肯定难不倒机智的你。
转移就可以枚举前一个危险水域,也就是 A ( x ) = ( 1 − q ) ∑ i = 1 k + 1 A ( x − i ) g ( i − 1 , 2 ) A(x)=(1-q)\sum_{i=1}^{k+1}A(x-i)g(i-1,2) A(x)=(1−q)i=1∑k+1A(x−i)g(i−1,2)
令 B ( x ) = ( 1 − q ) g ( x − 1 , 2 ) B(x)=(1-q)g(x-1,2) B(x)=(1−q)g(x−1,2),那么可以写成 A ( x ) = ∑ i = 1 k + 1 A ( x − i ) B ( i ) A(x)=\sum_{i=1}^{k+1}A(x-i)B(i) A(x)=i=1∑k+1A(x−i)B(i)
常系数齐次线性递推!这玩意儿可以优化到 O ( k 2 log k ) \mathcal O(k^2\log k) O(k2logk)。
代码
#include <cstdio>
#include <iostream>
#include <vector>
#include <algorithm>
#include <cstring>
using namespace std;
inline int readint(){
int a = 0; char c = getchar(), f = 1;
for(; c<'0' or c>'9'; c=getchar())
if(c == '-') f = -f;
for(; '0'<=c and c<='9'; c=getchar())
a = (a<<3)+(a<<1)+(c^48);
return a*f;
}
inline int qkpow(long long base,int q,int Mod){
int ans = 1; base %= Mod;
for(; q; q>>=1,base=base*base%Mod)
if(q&1) ans = base*ans%Mod;
return ans;
}
const int MaxN = 1005, Mod = 998244353;
typedef vector<int> poly;
poly operator * (const poly &a,const poly &b){
poly c; c.resize(a.size()+b.size()-1);
for(int i=0,lena=a.size(); i<lena; ++i)
for(int j=0,lenb=b.size(); j<lenb; ++j)
c[i+j] = (c[i+j]+1ll*a[i]*b[j])%Mod;
return c;
}
poly &operator %= (poly &a,const poly &b){
for(int i=int(a.size())-1,lenb=b.size(); i>=lenb-1; --i){
int f = a[i]/b[lenb-1];
for(int j=lenb; j; --j){
a[i+j-lenb] -= 1ll*b[j-1]*f%Mod;
a[i+j-lenb] = (a[i+j-lenb]+Mod)%Mod;
}
}
a.resize(b.size()-1);
return a;
}
inline poly qkpow(poly f,int q,poly Mod){
poly ans; ans.push_back(1);
for(; q; q>>=1,f=f*f,f%=Mod)
if(q&1) ans = ans*f, ans %= Mod;
return ans;
}
int g[MaxN][MaxN], n, q, powQ[MaxN<<1];
int a[MaxN]; /* 系数 */ int h[MaxN<<1];
int work(int k){
for(int i=2; i<MaxN; ++i)
g[0][i] = 1;
long long p = (1ll+Mod-q)%Mod; // p = 1-q
long long invP = qkpow(p,Mod-2,Mod); // 1/(1-q)
for(int i=1; i<=k; ++i){
g[i][k/i+2] = 0;
for(int j=k/i+1; j>=2; --j){
int &x = g[i][j] = 0;
for(int t=1; t<=i; ++t)
x = (x+1ll*g[t-1][j+1]*g[i-t][j])%Mod;
x= (p*x%Mod*powQ[j-1]+g[i][j+1])%Mod;
}
}
++ k; // 这样就不用一直写k+1了
for(int i=1; i<=k; ++i)
a[i] = p*g[i-1][2]%Mod;
h[0] = 1; // 边界条件
for(int i=1; i<=(k<<1); ++i){
int &t = h[i] = 0;
for(int j=1; j<=i and j<=k; ++j)
t = (t+1ll*h[i-j]*a[j])%Mod;
}
if(n <= (k<<1)) return h[n]*invP%Mod;
// n在读入的时候就已经加了一
poly f; f.resize(k+1);
for(int i=0; i<k; ++i)
f[i] = (Mod-a[k-i])%Mod;
f[k] = 1;
poly only_x; only_x.resize(2);
only_x[0] = 0, only_x[1] = 1;
f = qkpow(only_x,n-k,f);
int res = 0;
for(int i=0; i<k; ++i)
res = (res+1ll*f[i]*h[i+k])%Mod;
return res*invP%Mod;
}
int main(){
int k; // no need to be 全局变量
n = readint()+1, k = readint(), q = readint();
q = 1ll*q*qkpow(readint(),Mod-2,Mod)%Mod;
for(int i=(powQ[0]=1); i<=(k<<1); ++i)
powQ[i] = 1ll*powQ[i-1]*q%Mod;
printf("%d\n",(work(k)+Mod-work(k-1))%Mod);
return 0;
}