我们发现最后一个达到
k
k
k 的盒子用时可能是无穷大的,我们不方便直接求(因为我们很难写出一个式子来求极限)。
然后我们发现如果是求第一个到达
k
k
k 的盒子的用时,这个就肯定是有限的范围了。
因此想到
m
i
n
−
m
a
x
min-max
min−max 容斥:
E
(
max
t
i
)
=
E
(
∑
s
⊆
t
(
−
1
)
∣
s
∣
+
1
min
i
=
1
∣
s
∣
s
i
)
=
∑
i
=
1
n
(
−
1
)
i
+
1
⋅
(
n
i
)
⋅
T
(
i
)
\begin{aligned} &~~~~~~E(\max t_i)\\ &=E(\sum_{s\subseteq t}(-1)^{|s|+1}\min_{i=1}^{|s|} s_i)\\ &=\sum_{i=1}^{n}(-1)^{i+1}\cdot \binom{n}{i}\cdot T(i)\\ \end{aligned}
E(maxti)=E(s⊆t∑(−1)∣s∣+1i=1min∣s∣si)=i=1∑n(−1)i+1⋅(in)⋅T(i)
其中
T
(
i
)
T(i)
T(i) 表示只考虑前面
i
i
i 个盒子中,恰有一个盒子被装了
k
k
k 个球的期望步数。
发现现在的问题就是求
T
(
i
)
T(i)
T(i)。
我们可以考虑枚举步数
c
c
c,然后把算期望转化成算概率,因为总方案数已知,我们只要求合法方案数,然后算一下即可。
考虑如果一种合法方案在
i
i
i 个盒子里放了
c
c
c 个球,那么对答案的贡献就是:
Q
(
i
,
c
)
=
(
1
i
)
c
⋅
c
⋅
n
m
Q(i,c) = (\frac{1}{i})^c\cdot c\cdot \frac n m
Q(i,c)=(i1)c⋅c⋅mn
那么我们只要求出,将
c
c
c 个球放入
i
i
i 个盒子,恰好满足有且仅有一个盒子恰好达到
k
k
k 个的方案数。
因为最后一个球必须是放在某个盒子的第
k
k
k 个,这个球顺序固定,必须在最后一个。
而前面的球满足所有盒子的球数都不超过
k
−
1
k-1
k−1,前面的顺序可以互相调换(就是第
i
i
i 个放哪个箱子的顺序)。
因此我们最后考虑第
c
c
c 个球。
我们考虑一下,如果我们将
c
c
c 个球放入
i
i
i 个盒子。
其中固定了
a
j
a_j
aj 表示第
j
j
j 个盒子放
a
j
a_j
aj 个球,那么不同的顺序方案有(同一个盒子的球必须按固定顺序放):
c
!
∏
j
=
1
i
(
a
j
)
!
\frac{c!}{\prod_{j=1}^{i}(a_j)!}
∏j=1i(aj)!c!
我们只需要计算所有分配
a
j
a_j
aj 的方案的
1
∏
j
=
1
i
a
j
!
\frac{1}{\prod_{j=1}^i a_j!}
∏j=1iaj!1 的和即可。
设
f
(
i
,
c
)
f(i,c)
f(i,c) 表示在
i
i
i 个盒子里放
c
c
c 个球,每个盒子不超过
k
−
1
k-1
k−1 个,所有方案的
1
∏
j
=
1
i
a
j
!
\frac{1}{\prod_{j=1}^i a_j!}
∏j=1iaj!1 的和。
设
g
(
i
,
c
)
g(i,c)
g(i,c) 表示在
i
i
i 个盒子里放
c
c
c 个球,每个盒子不超过
k
−
1
k-1
k−1 个,并且选定一个恰好放了
k
−
1
k-1
k−1 的球的盒子作为特殊盒子(我们会将最后一个球放在这个盒子),所有方案的
1
∏
j
=
1
i
a
j
!
\frac{1}{\prod_{j=1}^i a_j!}
∏j=1iaj!1 的和。
不难得到:
f
(
i
,
c
)
=
∑
j
=
0
k
−
1
f
(
i
−
1
,
c
−
j
)
⋅
1
j
!
g
(
i
,
c
)
=
f
(
i
−
1
,
c
−
k
+
1
)
(
k
−
1
)
!
+
∑
j
=
0
k
−
1
g
(
i
−
1
,
c
−
j
)
j
!
f(i,c)=\sum_{j=0}^{k-1}f(i-1,c-j)\cdot \frac1{j!}\\ g(i,c)=\frac {f(i-1,c-k+1)}{(k-1)!}+\sum_{j=0}^{k-1}\frac{g(i-1,c-j)}{j!}
f(i,c)=j=0∑k−1f(i−1,c−j)⋅j!1g(i,c)=(k−1)!f(i−1,c−k+1)+j=0∑k−1j!g(i−1,c−j)
然后枚举最后一个球之前放了几个球,就有:
T
(
i
)
=
∑
c
=
0
n
(
k
−
1
)
g
(
i
,
c
)
⋅
c
!
⋅
Q
(
i
,
c
+
1
)
T(i)=\sum_{c=0}^{n(k-1)}g(i,c)\cdot c!\cdot Q(i,c+1)
T(i)=c=0∑n(k−1)g(i,c)⋅c!⋅Q(i,c+1)
瓶颈就是求
f
,
g
f,g
f,g 要用
O
(
n
2
k
2
)
\mathcal O(n^2k^2)
O(n2k2) 的复杂度。
用 FFT 就可以优化到
O
(
n
2
k
log
(
n
k
)
)
\mathcal O(n^2k\log(nk))
O(n2klog(nk))。
#include<bits/stdc++.h>template<classT>inlinevoidread(T &x){staticchar ch;staticbool opt;while(!isdigit(ch =getchar())&& ch !='-');
x =(opt = ch =='-')?0: ch -'0';while(isdigit(ch =getchar()))
x = x *10+ ch -'0';if(opt) x =~x +1;}template<classT>inlinevoidputint(T x){staticchar buf[15],*tail = buf;if(!x)putchar('0');else{if(x <0)putchar('-'), x =~x +1;for(; x; x /=10)*++tail = x %10+'0';for(; tail != buf;--tail)putchar(*tail);}}constint MaxN =55;constint MaxK =1e3+5;constint MaxB = MaxN * MaxK;constint mod =998244353;//g=3inlineintqpow(int x,int y){int res =1;for(; y; y >>=1, x =1LL* x * x % mod)if(y &1) res =1LL* res * x % mod;return res;}namespace polynomial
{constint MaxN =1e6+5;int L, P, rev[MaxN];inlinevoidpoly_init(int n){
P =0, L =1;while(L <= n) L <<=1,++P;for(int i =1; i < L;++i)
rev[i]= rev[i >>1]>>1|(i &1)<< P -1;}inlinevoidDFT(int*a,int n,int opt){int g = opt ==1?3:(mod +1)/3;for(int i =1; i < L;++i)if(i < rev[i]) std::swap(a[i], a[rev[i]]);for(int k =1; k < n; k <<=1){int omega =qpow(g,(mod -1)/(k <<1));for(int i =0; i < n; i += k <<1){int x =1;for(int j =0; j < k;++j){int u = a[i + j], v =1LL* x * a[i + j + k]% mod;
a[i + j]= u + v >= mod ? u + v - mod : u + v;
a[i + j + k]= u - v <0? u - v + mod : u - v;
x =1LL* x * omega % mod;}}}if(opt ==-1){int inv =qpow(n, mod -2);for(int i =0; i < n;++i)
a[i]=1LL* a[i]* inv % mod;}}inlinevoidpoly_mul(int*a,int*b,int*c,int na,int nb,int nc){poly_init(na + nb);staticint ta[MaxN], tb[MaxN];memset(ta,0,sizeof(ta));memset(tb,0,sizeof(tb));for(int i =0; i < na;++i)
ta[i]= a[i];for(int i =0; i < nb;++i)
tb[i]= b[i];DFT(ta, L,1),DFT(tb, L,1);for(int i =0; i < L;++i)
c[i]=1LL* ta[i]* tb[i]% mod;DFT(c, L,-1);for(int i = nc; i < L;++i)
c[i]=0;}}using polynomial::poly_mul;int n, K, B;int fra[MaxB], ind[MaxB];int f[MaxN][MaxB], g[MaxN][MaxB], T[MaxN];inlinevoidinit_fra(int n){
fra[0]=1;for(int i =1; i <= n;++i)
fra[i]=1LL* fra[i -1]* i % mod;
ind[n]=qpow(fra[n], mod -2);for(int i = n -1; i >=0;--i)
ind[i]=1LL* ind[i +1]*(i +1)% mod;}inlineintC(int n,int m){if(n <0|| m <0|| n < m)return0;return1LL* fra[n]* ind[m]% mod * ind[n - m]% mod;}intmain(){freopen("sanrd.in","r",stdin);freopen("sanrd.out","w",stdout);
std::cin >> n >> K;
B = n *(K -1);init_fra(n * K);
f[0][0]=1;for(int i =1; i <= n;++i){poly_mul(f[i -1], ind, f[i], B +1, K, B +1);poly_mul(g[i -1], ind, g[i], B +1, K, B +1);for(int c =0; c <= B;++c)if(c - K +1>=0)
g[i][c]=(g[i][c]+1LL* f[i -1][c - K +1]* ind[K -1])% mod;for(int c =0; c <= B;++c){int tmp =1LL* fra[c]*qpow(i,1LL*(mod -2)*(c +1)%(mod -1))% mod;
tmp =1LL* tmp * n *(c +1)% mod *qpow(i, mod -2)% mod;
T[i]=(T[i]+1LL* g[i][c]* tmp)% mod;}}int ans =0;for(int i =1; i <= n;++i){int val =1LL*C(n, i)* T[i]% mod;if(i &1)
ans =(ans + val)% mod;else
ans =(ans + mod - val)% mod;}printf("%d\n", ans);return0;}