首先有递推式(第一类无符号斯特林数):

$$s(i,j)=s(i-1,j-1)+(i-1)\cdot s(i-1,j)$$
为方便卷积,写成这样(使每一项中两个因子的第二维下标之和都为 $j$):

$$s(i,j)=s(i-1,j-1)\cdot b(i,1)+b(i,0)\cdot s(i-1,j)$$
其中 $b(i,1)=1,\ b(i,0)=i-1$。
那么把 $s(i)$ 看成一个多项式,$s(i,j)$ 为这个多项式 $x^j$ 项的系数,初值 $s(0,0)=1$(即 $s(0)=1$)。$b(i)$ 同理,即 $b(i)=x+(i-1)$。
那么 $s(i)=s(i-1)\cdot b(i)$。
于是把 $s(0)$ 与 $b(1),b(2),\dots,b(n)$ 都乘起来,得到的多项式就是 $s(n)=\prod_{i=1}^{n}\bigl(x+(i-1)\bigr)$。
这个多项式的 $x^i$ 项的系数就是 $s(n,i)$。
分治 NTT 即可,时间复杂度 $O(n\log^2 n)$。
code
// Reads n, a1, b1 from stdin and prints
//     C(a1 + b1 - 2, a1 - 1) * s(n - 1, a1 + b1 - 2)  (mod 998244353),
// where s(m, k) is the unsigned Stirling number of the first kind, i.e. the
// coefficient of x^k in x(x+1)...(x+m-1). That product is evaluated with a
// divide-and-conquer NTT in O(n log^2 n).
// NOTE(review): presumably this counts permutations of length n with a1 prefix
// maxima and b1 suffix maxima -- confirm against the original problem statement.
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

namespace {

using Poly = std::vector<int>;  // coefficients, lowest degree first

// NTT-friendly prime 119 * 2^23 + 1: transform sizes up to 2^23 are supported.
constexpr int kMod = 998244353;
// Primitive root modulo kMod and its inverse (3 * 332748118 == kMod + 1).
constexpr int kRoot = 3;
constexpr int kRootInv = 332748118;

// (a * b) mod kMod for a, b in [0, kMod).
int mul_mod(int a, int b) {
    return static_cast<int>(static_cast<long long>(a) * b % kMod);
}

// base^exp mod kMod by binary exponentiation; base in [0, kMod), exp >= 0.
int pow_mod(int base, long long exp) {
    int result = 1;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) result = mul_mod(result, base);
        base = mul_mod(base, base);
    }
    return result;
}

// In-place iterative number-theoretic transform. a.size() must be a power of
// two no larger than 2^23. With invert == true this is the inverse transform,
// including the final division by a.size().
void ntt(Poly& a, bool invert) {
    const int size = static_cast<int>(a.size());

    // Bit-reversal permutation computed incrementally, so no shared rev[]
    // table (and no shift by a negative amount when size == 1).
    for (int i = 1, j = 0; i < size; ++i) {
        int bit = size >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }

    for (int half = 1; half < size; half <<= 1) {
        const int w_step = pow_mod(invert ? kRootInv : kRoot, (kMod - 1) / (2 * half));
        for (int start = 0; start < size; start += 2 * half) {
            int w = 1;
            for (int j = 0; j < half; ++j) {
                const int u = a[start + j];
                const int v = mul_mod(w, a[start + j + half]);
                // u + v and u - v + kMod stay below 2 * kMod < 2^31, so int is safe.
                a[start + j] = (u + v >= kMod) ? u + v - kMod : u + v;
                a[start + j + half] = (u - v < 0) ? u - v + kMod : u - v;
                w = mul_mod(w, w_step);
            }
        }
    }

    if (invert) {
        const int size_inv = pow_mod(size, kMod - 2);
        for (int& x : a) x = mul_mod(x, size_inv);
    }
}

// Product of two polynomials; returns an empty polynomial if either is empty.
// Buffers are sized per call, so there is no fixed upper limit on the degree
// beyond the 2^23 transform-size limit of kMod.
Poly multiply(Poly lhs, Poly rhs) {
    if (lhs.empty() || rhs.empty()) return {};
    const std::size_t result_size = lhs.size() + rhs.size() - 1;
    std::size_t size = 1;
    while (size < result_size) size <<= 1;

    lhs.resize(size, 0);
    rhs.resize(size, 0);
    ntt(lhs, false);
    ntt(rhs, false);
    for (std::size_t i = 0; i < size; ++i) lhs[i] = mul_mod(lhs[i], rhs[i]);
    ntt(lhs, true);

    lhs.resize(result_size);
    return lhs;
}

// Coefficients of (x + lo)(x + lo + 1)...(x + hi - 1); an empty range yields 1.
// Splitting the range in half keeps both factors of similar degree, giving
// O(m log^2 m) overall for m = hi - lo. Intermediate results are freed as
// soon as they are multiplied.
Poly rising_product(int lo, int hi) {
    if (hi <= lo) return {1};
    if (hi - lo == 1) return {lo, 1};
    const int mid = lo + (hi - lo) / 2;
    return multiply(rising_product(lo, mid), rising_product(mid, hi));
}

// Binomial coefficient C(top, bottom) mod kMod; 0 when bottom is outside [0, top].
// Assumes top < kMod, so no factor of the denominator is divisible by kMod.
int binomial(int top, int bottom) {
    if (bottom < 0 || bottom > top) return 0;
    int numerator = 1;
    int denominator = 1;
    for (int i = 1; i <= bottom; ++i) {
        numerator = mul_mod(numerator, top - bottom + i);
        denominator = mul_mod(denominator, i);
    }
    return mul_mod(numerator, pow_mod(denominator, kMod - 2));
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    int a1 = 0;
    int b1 = 0;
    std::cin >> n >> a1 >> b1;

    // Index of the needed Stirling coefficient; long long avoids overflow of a1 + b1.
    const long long k = static_cast<long long>(a1) + b1 - 2;

    int answer = 0;
    // s(n-1, k) can be non-zero only for 0 <= k <= n-1. Checking this first
    // also rules out the negative / out-of-range indices the factorial tables
    // could previously be read at.
    if (n >= 1 && a1 >= 1 && b1 >= 1 && k <= n - 1) {
        const Poly stirling = rising_product(0, n - 1);  // stirling[j] == s(n-1, j)
        const int index = static_cast<int>(k);
        answer = mul_mod(binomial(index, a1 - 1), stirling[index]);
    }

    std::cout << answer << '\n';
    return 0;
}