题目大意:
有一个
1001×n
1001
×
n
的的网格,每个格子有
q
q
的概率是安全的,的概率是危险的。
定义一个矩形是合法的当且仅当:
1.这个矩形中每个格子都是安全的
2.必须紧贴网格的下边界
问你最大的合法子矩形大小恰好为
k
k
的概率是多少。
解题思路:
首先求恰好为的概率一般转化为求 ≤k ≤ k 的概率减去 ≤k−1 ≤ k − 1 的概率。
如何求 ≤k ≤ k 的概率的概率呢?
设
fi,j
f
i
,
j
表示一段长度为
j
j
的海滩,前行默认没有障碍(前
i
i
行概率当1算),第行有障碍,最大子矩阵不超过
k
k
的概率,那么我们顺次枚举第行危险的格子,得到转移方程有
fi,j=∑t=1j∑k≥i∑l>i(fk,t−1q(k−i)(t−1))(1−q)(fl,j−tq(l−i)(l−i))
f
i
,
j
=
∑
t
=
1
j
∑
k
≥
i
∑
l
>
i
(
f
k
,
t
−
1
q
(
k
−
i
)
(
t
−
1
)
)
(
1
−
q
)
(
f
l
,
j
−
t
q
(
l
−
i
)
(
l
−
i
)
)
也就是将左右两部分拼起来,中间定一个危险的格子限制高度,令
i∗j≤k
i
∗
j
≤
k
,再用前缀和优化(
Fi,j=∑k≥ifk,j
F
i
,
j
=
∑
k
≥
i
f
k
,
j
),复杂度为
O(k2)
O
(
k
2
)
统计答案时设
ansi
a
n
s
i
表示前
i
i
列合法的概率,则有:
ansi=∑j=1k+1ansi−j(1−q)(F1,j−1qj−1),i>k
a
n
s
i
=
∑
j
=
1
k
+
1
a
n
s
i
−
j
(
1
−
q
)
(
F
1
,
j
−
1
q
j
−
1
)
,
i
>
k
也就是限制每段安全长度不超过
k
k
。
注意到如果我们计算出,后面的 ansi=∑k+1j=1ajansi−j a n s i = ∑ j = 1 k + 1 a j a n s i − j 是常系数齐次递推式,可以利用特征多项式优化矩阵快速幂至 O(k2logk) O ( k 2 l o g k ) ,这样就做完了。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
int getint()
{
int i=0,f=1;char c;
for(c=getchar();(c!='-')&&(c<'0'||c>'9');c=getchar());
if(c=='-')c=getchar(),f=-1;
for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
return i*f;
}
const int N=2005,mod=998244353;
int n;
ll Pow_q[N],f[N][N],g[N],a[N],b[N],c[N],ans[N],q,p;
int Pow(ll x,int y)
{
ll res=1;
for(;y;y>>=1,x=x*x%mod)
if(y&1)res=res*x%mod;
return res;
}
void mul(ll *a,ll *b,int k)
{
for(int i=0;i<=2*k;i++)c[i]=0;
for(int i=0;i<=k;i++)
for(int j=0;j<=k;j++)
c[i+j]=(c[i+j]+a[i]*b[j])%mod;
for(int i=2*k;i>=k+1;i--)
{
for(int j=0;j<=k;j++)
c[i-k-1+j]=(c[i-k-1+j]+c[i]*g[k+1-j])%mod;
c[i]=0;
}
for(int i=0;i<=k;i++)a[i]=c[i];
}
void Pow(ll *a,int y,ll *b,int k)
{
b[0]=1;
for(;y;y>>=1,mul(a,a,k))
if(y&1)mul(b,a,k);
}
int solve(int k)
{
if(!k)return Pow(p,n);
memset(f,0,sizeof(f));
f[k+1][0]=1;
for(int i=k;i;i--)
{
f[i][0]=1;
int m=min(n,k/i);
for(int j=0;j<=m;j++)g[j]=f[i+1][j]*Pow_q[j]%mod*p%mod;
for(int j=1;j<=m;j++)
{
for(int t=0;t<j;t++)
f[i][j]=(f[i][j]+g[t]*f[i][j-t-1])%mod;
f[i][j]=(f[i][j]+f[i+1][j]*Pow_q[j])%mod;
}
}
for(int i=1;i<=k+1;i++)g[i]=f[1][i-1]*Pow_q[i-1]%mod*p%mod;
ans[0]=1;
for(int i=1;i<=k;i++)
{
ans[i]=f[1][i]*Pow_q[i]%mod;
for(int j=1;j<=i;j++)ans[i]=(ans[i]+g[j]*ans[i-j])%mod;
}
if(n<=k)return ans[n];
memset(a,0,sizeof(a)),memset(b,0,sizeof(b));
a[1]=1;Pow(a,n,b,k);
ll res=0;
for(int i=0;i<=k;i++)res=(res+b[i]*ans[i])%mod;
return (res+mod)%mod;
}
int main()
{
//freopen("lx.in","r",stdin);
n=getint();int k=getint();ll x=getint(),y=getint();
q=x*Pow(y,mod-2)%mod,p=(1-q+mod)%mod;
Pow_q[0]=1;
for(int i=1;i<=k;i++)Pow_q[i]=Pow_q[i-1]*q%mod;
printf("%d\n",(solve(k)-solve(k-1)+mod)%mod);
}