题解:
先推一波公式。
这是一个容斥原理的式子:
$$S(n,m)=\frac{1}{m!}\sum_{k=0}^{m}(-1)^{k}\,C_{m}^{k}\,(m-k)^{n}$$
就是枚举有多少个盒子是空的,容斥一下。由于盒子是一样的,所以最后要除以m!。
=>
$$S(n,m)=\frac{1}{m!}\sum_{k=0}^{m}(-1)^{k}\,\frac{m!}{k!\,(m-k)!}\,(m-k)^{n}$$
=>
$$S(n,m)=\sum_{k=0}^{m}(-1)^{k}\,\frac{1}{k!\,(m-k)!}\,(m-k)^{n}$$
=>
$$S(n,m)=\sum_{k=0}^{m}\frac{(-1)^{k}}{k!}\cdot\frac{(m-k)^{n}}{(m-k)!}$$
我们再看过来要求的式子。
$$f(n)=\sum_{i=0}^{n}\sum_{j=0}^{i}S(i,j)\,2^{j}\,j!$$
=>
$$f(n)=\sum_{i=0}^{n}\sum_{j=0}^{n}S(i,j)\,2^{j}\,j!$$
(当 $j>i$ 时 $S(i,j)=0$,所以 $j$ 的上界可以从 $i$ 放宽到 $n$。)
=>
$$f(n)=\sum_{j=0}^{n}2^{j}\,j!\sum_{i=0}^{n}S(i,j)$$
=>
$$f(n)=\sum_{j=0}^{n}2^{j}\,j!\sum_{i=0}^{n}\sum_{k=0}^{j}\frac{(-1)^{k}}{k!}\cdot\frac{(j-k)^{i}}{(j-k)!}$$
=>
$$f(n)=\sum_{j=0}^{n}2^{j}\,j!\sum_{k=0}^{j}\frac{(-1)^{k}}{k!}\sum_{i=0}^{n}\frac{(j-k)^{i}}{(j-k)!}$$
于是我们可以对每个 $k$ 预处理
$$\frac{(-1)^{k}}{k!}$$
和
$$\sum_{i=0}^{n}\frac{(j-k)^{i}}{(j-k)!}$$
的值(后者是等比数列求和:令 $t=j-k$,当 $t=0$ 时为 $1$,当 $t=1$ 时为 $n+1$,否则为 $\frac{t^{n+1}-1}{(t-1)\,t!}$)。对 $k$ 的求和正好是这两个序列的卷积,然后就可以NTT卷积求解啦!
代码
#include<cstdio>
#include<algorithm>
using namespace std;
// 64-bit signed integer type used for all modular arithmetic in this file.
typedef long long ll;

// Capacity of every global array below (270005 > 2^18 = 262144).
const int N = 270005;

// Prime modulus 998244353 = 119 * 2^23 + 1; ntt() uses 3 as its generator.
const ll mod = 998244353;

// n: transform length, set in main and read by ntt().
int n;
// m: the integer read from standard input in main.
int m;
// rev[i]: bit-reversal partner of index i, filled in main for the current n.
int rev[N];

// ans: running total that main prints at the end.
ll ans;
// jc[i]: i! reduced modulo mod, filled in main.
ll jc[N];
// a, b: the two coefficient sequences that main multiplies via ntt().
ll a[N];
ll b[N];
// Computes a^x modulo `mod` by binary (square-and-multiply) exponentiation.
//
// a: base; any value, including negatives, is first reduced into [0, mod).
// x: exponent; expected to be non-negative (negative exponents are not
//    supported by the shift-based loop below).
// Returns the power as a value in [0, mod); an exponent of 0 yields 1.
ll fastpow(ll a, ll x) {
    // Reduce the base before multiplying. `%` keeps the sign of the dividend,
    // so a negative base needs one extra `+ mod`. After this step every
    // operand is below mod (< 2^30), so the products below stay within the
    // range of a 64-bit integer and the result can never be negative.
    a %= mod;
    if (a < 0) {
        a += mod;
    }

    ll res = 1;
    while (x) {
        if (x & 1) {
            res = res * a % mod;  // this bit of the exponent is set
        }
        x >>= 1;
        a = a * a % mod;  // a now holds the next power-of-two power of the base
    }
    return res;
}
// Modular inverse of x modulo `mod`, computed as x^(mod-2) through fastpow
// (Fermat's little theorem; relies on mod being prime and x not being a
// multiple of mod).
ll getinv(ll x) {
    const ll exponent = mod - 2;
    return fastpow(x, exponent);
}
// In-place number-theoretic transform of the first `n` entries of `a`
// (n is the global transform length, rev the global bit-reversal table).
//
// a:   coefficient array, overwritten with the transformed values.
// dft: -1 selects the inverse transform (inverse roots plus a final
//      multiplication by the inverse of n); any other value runs the
//      forward transform.
void ntt(ll *a, int dft) {
    const bool inverse = (dft == -1);

    // Reorder the input into bit-reversed order; the `pos < partner` check
    // makes sure each pair is exchanged only once.
    for (int pos = 0; pos < n; ++pos) {
        const int partner = rev[pos];
        if (pos < partner) {
            std::swap(a[pos], a[partner]);
        }
    }

    // Butterfly passes: `half` is the half-length of the groups being merged.
    for (int half = 1; half < n; half <<= 1) {
        const int span = half << 1;

        // Root of unity of order 2 * half, derived from the generator 3.
        ll root = fastpow(3, (mod - 1) / half / 2);
        if (inverse) {
            root = getinv(root);
        }

        for (int start = 0; start < n; start += span) {
            ll twiddle = 1;
            for (int off = 0; off < half; ++off) {
                const int lo = start + off;
                const int hi = lo + half;
                const ll even = a[lo];
                const ll odd = twiddle * a[hi] % mod;
                a[lo] = (even + odd) % mod;
                a[hi] = (even - odd + mod) % mod;  // + mod keeps it non-negative
                twiddle = twiddle * root % mod;
            }
        }
    }

    // The inverse transform additionally divides every entry by n.
    if (inverse) {
        const ll scale = getinv(n);
        for (int pos = 0; pos < n; ++pos) {
            a[pos] = a[pos] * scale % mod;
        }
    }
}
int main(){
scanf("%d",&m);
jc[0]=1;
for(int i=1;i<=m;i++){
jc[i]=jc[i-1]*i%mod;
}
for(int i=0;i<=m;i++){
a[i]=(fastpow(-1,i)*getinv(jc[i])+mod)%mod;
}
b[0]=1;
b[1]=m+1;
for(int i=2;i<=m;i++){
b[i]=(fastpow(i,m+1)-1)*getinv(i-1)%mod*getinv(jc[i])%mod;
}
for(n=1;n<=m*2;n<<=1);
for(int i=0;i<n;i++){
rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
}
ntt(a,1);
ntt(b,1);
for(int i=0;i<n;i++){
a[i]=a[i]*b[i]%mod;
}
ntt(a,-1);
for(int i=0;i<=m;i++){
ans=(ans+fastpow(2,i)*jc[i]%mod*a[i]%mod)%mod;
}
printf("%lld\n",ans);
return 0;