(传说中的
n
2
k
n^2k
n2k做法)
首先强制鸽子饱的顺序为
1
−
n
1-n
1−n,最后答案乘
n
!
n!
n!即可
我们只需要考虑喂一次喂到了未饱的鸽子的情况,我们称之为有效喂食
下一次喂食为有效喂食的概率为
n
−
x
n
\frac{n-x}{n}
nn−x,其中x为已经饱了的鸽子数
所以两次有效喂食之间的无效喂食次数的期望为
n
n
−
x
\frac{n}{n-x}
n−xn
这样我们就消除了无效喂食的影响了
设
g
[
i
]
[
j
]
g[i][j]
g[i][j]表示已经进行了
i
i
i次有效喂食,
j
j
j只鸽子已经饱了的概率,
f
[
i
]
[
j
]
f[i][j]
f[i][j]表示期望
转移有两种情况:
1.下一次喂食没有喂饱一只鸽子,这时概率就直接乘上
1
n
−
x
\frac{1}{n-x}
n−x1,期望加上当前概率乘上无效喂食次数的期望
2.下一次喂食喂饱了一只鸽子,这时我们要从喂饱上一个鸽子之后的所有有效喂食中选出k-1个出来作为喂这只将要饱的鸽子的喂食,即
C
i
−
j
∗
k
k
−
1
C_{i-j*k}^{k-1}
Ci−j∗kk−1,转移的时候就用情况1的转移乘上这个系数即可
Code:
#include<bits/stdc++.h>
#define mod 998244353
using namespace std;
inline int read(){
int res=0,f=1;char ch=getchar();
while(!isdigit(ch)) {if(ch=='-') f=-f;ch=getchar();}
while(isdigit(ch)) {res=(res<<1)+(res<<3)+(ch^48);ch=getchar();}
return res*f;
}
const int N=105,K=5005;
inline int add(int x,int y){x+=y;if(x>=mod) x-=mod;return x;}
inline int dec(int x,int y){x-=y;if(x<0) x+=mod;return x;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline void inc(int &x,int y){x+=y;if(x>=mod) x-=mod;}
inline void Dec(int &x,int y){x-=y;if(x<0) x+=mod;}
inline void Mul(int &x,int y){x=1ll*x*y%mod;}
inline int ksm(int a,int b){int res=1;for(;b;b>>=1,a=mul(a,a)) if(b&1) res=mul(res,a);return res;}
int fac[N*K],ifac[N*K];
int inv[N],p[N],e[N];
inline void init(int n,int k){
fac[0]=fac[1]=ifac[0]=ifac[1]=1;
for(int i=2;i<=n*k;i++) fac[i]=mul(fac[i-1],i);
ifac[n*k]=ksm(fac[n*k],mod-2);
for(int i=n*k-1;i;i--) ifac[i]=mul(ifac[i+1],i+1);
inv[1]=1;
for(int i=2;i<=n;i++) inv[i]=mul((mod-mod/i),inv[mod%i]);
for(int i=0;i<=n;i++) p[i]=inv[n-i],e[i]=mul(n,inv[n-i]);
}
inline int C(int n,int m){if(n<0 || m<0 || n<m) return 0;return mul(fac[n],mul(ifac[m],ifac[n-m]));}
int f[N*K][N],g[N*K][N];
int main(){
int n=read(),k=read();init(n,k);
f[0][0]=0,g[0][0]=1;
for(int i=0;i<=n*k;i++)
for(int j=0;j<=i/k;j++) if(g[i][j]){
int P=mul(g[i][j],p[j]),E=add(mul(P,e[j]),mul(p[j],f[i][j])),Com=C(i-j*k,k-1);
inc(f[i+1][j],E);
inc(g[i+1][j],P);
inc(f[i+1][j+1],mul(E,Com));
inc(g[i+1][j+1],mul(P,Com));
}
cout<<mul(f[n*k][n],fac[n]);
return 0;
}