测试地址:求和
做法:本题需要用到第二类斯特林数+NTT。
从题目中给的递推式或者根据组合数学的知识,第二类斯特林数
S(i,j)
S
(
i
,
j
)
的组合意义是:将
i
i
个有区别的球放入个无区别的盒子的方案数。由此我们可以得到通项公式:
S(i,j)=1j!∑jk=0(−1)kCkj(j−k)i
S
(
i
,
j
)
=
1
j
!
∑
k
=
0
j
(
−
1
)
k
C
j
k
(
j
−
k
)
i
这其实就是一个容斥的形式,相当于枚举强制哪些盒子是空的,至于要乘一个
1j!
1
j
!
是因为盒子无区别,而里面算的方案是有区别的。
那么将这个式子代入题目要求的式子,有:
f(n)=∑ni=0∑ij=02j∑jk=0(−1)kCkj(j−k)i
f
(
n
)
=
∑
i
=
0
n
∑
j
=
0
i
2
j
∑
k
=
0
j
(
−
1
)
k
C
j
k
(
j
−
k
)
i
将组合数拆开,整理得:
f(n)=∑ni=0∑ij=02jj!∑jk=0(−1)kk!⋅(j−k)i(j−k)!
f
(
n
)
=
∑
i
=
0
n
∑
j
=
0
i
2
j
j
!
∑
k
=
0
j
(
−
1
)
k
k
!
⋅
(
j
−
k
)
i
(
j
−
k
)
!
我们发现后半部分已经很像一个卷积的形式了,但是因为它还和
i
i
有关,所以我们想办法把换进去。
我们知道当
j>i
j
>
i
时
S(i,j)=0
S
(
i
,
j
)
=
0
,所以上式中
j
j
的上限可以换成,那么就可以把
i
i
换进去,得到:
那么这个式子的后半部分就是函数
g(x)=(−1)xx!
g
(
x
)
=
(
−
1
)
x
x
!
和函数
h(x)=∑ni=0xix!
h
(
x
)
=
∑
i
=
0
n
x
i
x
!
的卷积了,可以用NTT求出,而求
h(x)
h
(
x
)
时,我们发现它是一个等比数列的前缀和,直接用等比数列求和公式求即可。特别地,
h(0)=1,h(1)=n+1
h
(
0
)
=
1
,
h
(
1
)
=
n
+
1
,直接用公式算的话这两个会算错。
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const ll g=3;
ll n,fac[100010],inv[100010],invfac[100010];
ll a[1000010]={0},b[1000010]={0};
int r[1000010];
ll power(ll a,ll b)
{
ll s=1,ss=a;
while(b)
{
if (b&1) s=s*ss%mod;
ss=ss*ss%mod;b>>=1;
}
return s;
}
void NTT(ll *a,ll type,int n)
{
for(int i=0;i<n;i++)
if (i<r[i]) swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1)
{
ll W=power(g,(mod-1)/(mid<<1));
if (type==-1) W=power(W,mod-2);
for(int l=0;l<n;l+=(mid<<1))
{
ll w=1;
for(int k=0;k<mid;k++,w=w*W%mod)
{
ll x=a[l+k],y=w*a[l+mid+k]%mod;
a[l+k]=(x+y)%mod;
a[l+mid+k]=(x-y+mod)%mod;
}
}
}
if (type==-1)
{
ll inv=power(n,mod-2);
for(int i=0;i<n;i++)
a[i]=a[i]*inv%mod;
}
}
int main()
{
scanf("%lld",&n);
fac[0]=fac[1]=inv[1]=invfac[0]=invfac[1]=1;
for(ll i=2;i<=n;i++)
{
fac[i]=fac[i-1]*i%mod;
inv[i]=(mod-mod/i)*inv[mod%i]%mod;
invfac[i]=invfac[i-1]*inv[i]%mod;
}
for(ll i=0;i<=n;i++)
{
a[i]=(((i%2)?-1:1)*invfac[i]+mod)%mod;
if (i==0) b[i]=1;
if (i==1) b[i]=n+1;
if (i>1) b[i]=(power(i,n+1)-1+mod)*invfac[i]%mod*inv[i-1]%mod;
}
int x=1,bit=0;
while(x<=(n<<2)) x<<=1,bit++;
r[0]=0;
for(int i=1;i<x;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));
NTT(a,1,x),NTT(b,1,x);
for(int i=0;i<x;i++)
a[i]=a[i]*b[i]%mod;
NTT(a,-1,x);
ll ans=0;
for(ll i=0,j=1;i<=n;i++,j=j*2ll%mod)
{
ll tmp=j*fac[i]%mod;
tmp=tmp*a[i]%mod;
ans=(ans+tmp)%mod;
}
printf("%lld",ans);
return 0;
}