由于当 $j > i$ 时 $S(i,j) = 0$,可以把内层求和的上界从 $i$ 扩展到 $n$:
$$\sum_{i=0}^{n}\sum_{j=0}^{i} S(i,j)\cdot 2^j\cdot j! = \sum_{i=0}^{n}\sum_{j=0}^{n} S(i,j)\cdot 2^j\cdot j!$$
第二类斯特林数 $S(i,j)$ 有通项公式:
$$S(i,j) = \frac{1}{j!}\sum_{k=0}^{j}\binom{j}{k}(-1)^k (j-k)^i$$
将第二类斯特林数的通项公式代入,并交换 $i$ 与 $j$ 的求和顺序:
$$
\begin{aligned}
\sum_{i=0}^{n}\sum_{j=0}^{n} S(i,j)\cdot 2^j\cdot j!
&= \sum_{j=0}^{n} 2^j\cdot j! \sum_{i=0}^{n} \frac{1}{j!}\sum_{k=0}^{j}\binom{j}{k}(-1)^k (j-k)^i \\
&= \sum_{j=0}^{n} 2^j\cdot j! \sum_{k=0}^{j} \frac{(-1)^k}{k!}\cdot\frac{\sum_{i=0}^{n}(j-k)^i}{(j-k)!}
\end{aligned}
$$
其中第二步利用了 $\frac{1}{j!}\binom{j}{k} = \frac{1}{k!\,(j-k)!}$。
右边的内层求和是卷积形式。令
$$h(m) = \frac{(-1)^m}{m!},\qquad g(m) = \frac{1}{m!}\sum_{i=0}^{n} m^i$$
(这里用 $m$ 作自变量,以免与求和上界 $n$ 混淆)。$g(m)$ 中的 $\sum_{i=0}^{n} m^i$ 是等比数列求和:$m = 0$ 时为 $1$(约定 $0^0 = 1$),$m = 1$ 时为 $n + 1$,其余情况为 $\frac{m^{n+1} - 1}{m - 1}$。
最终式子为:
$$\sum_{j=0}^{n} 2^j\cdot j! \sum_{k=0}^{j} h(k)\, g(j-k)$$
于是内层卷积 $h * g$ 可以用 NTT 一次求出,总复杂度为 $O(n \log n)$。
代码:
#include<bits/stdc++.h>
using namespace std;
// 64-bit working type for all modular arithmetic.
typedef long long ll;
// Modulus for every computation: 998244353 = 119 * 2^23 + 1, an NTT-friendly prime.
const int mod = 998244353;
// Capacity of the global tables below (10^6 plus 10 slots of slack).
const int maxn = 1000010;
// Problem size, read in main.
ll n;
// Factorials, inverse factorials and powers of two modulo `mod`.
ll fact[maxn], ifact[maxn], pw[maxn];
// Coefficient arrays for the two convolution operands h and g.
ll h[maxn], g[maxn];
ll fpow(ll a,ll b) {
ll r = 1;
while(b) {
if(b & 1) r = r * a % mod;
a = a * a % mod;
b >>= 1;
}
return r;
}
void change(ll t[],int len) {
for(int i = 1, j = len / 2; i < len - 1; i++) {
if(i < j) swap(t[i],t[j]);
int k = len / 2;
while(j >= k) {
j -= k;
k /= 2;
}
if(j < k) j += k;
}
}
void NTT(ll t[],int len,int type) {
change(t,len);
for(int s = 2; s <= len; s <<= 1) {
ll wn = fpow(3,(mod - 1) / s);
if(type == -1) wn = fpow(wn,mod - 2);
for(int j = 0; j < len; j += s) {
ll w = 1;
for(int k = 0; k < s / 2; k++) {
ll u = t[j + k],v = t[j + k + s / 2] * w % mod;
t[j + k] = (u + v) % mod;
t[j + k + s / 2] = (u - v + mod) % mod;
w = w * wn % mod;
}
}
}
if(type == -1) {
ll inv = fpow(len,mod - 2);
for(int i = 0; i < len; i++)
t[i] = t[i] * inv % mod;
}
}
ll cal(ll x) {
x %= mod;
if(x == 1) return n + 1;
if(x == 0) return 1;
ll t = x - 1;
if(t < 0) t += mod;
return (fpow(x,n + 1) - 1 + mod) % mod * fpow(t,mod - 2) % mod;
}
int main() {
scanf("%lld",&n);
fact[0] = 1;
for(int i = 1; i <= maxn - 10; i++)
fact[i] = fact[i - 1] * i % mod;
ifact[maxn - 10] = fpow(fact[maxn - 10],mod - 2);
for(int i = maxn - 11; i >= 0; i--)
ifact[i] = ifact[i + 1] * (i + 1) % mod;
pw[0] = 1;
for(int i = 1; i <= maxn - 10; i++)
pw[i] = pw[i - 1] * 2 % mod;
for(int i = 0; i <= n; i++) {
h[i] = ifact[i];
if(i & 1) h[i] *= -1;
if(h[i] < 0) h[i] += mod;
g[i] = ifact[i] * cal(i) % mod;
}
int len = 1;
while(len <= 2 * (n + 1)) len <<= 1;
NTT(h,len,1);NTT(g,len,1);
for(int i = 0; i < len; i++)
h[i] = 1ll * h[i] * g[i] % mod;
NTT(h,len,-1);
ll ans = 0;
for(int i = 0,j = 1; i <= n; i++,j = (j + j) % mod)
ans = (ans + pw[i] * fact[i] % mod * h[i] % mod) % mod;
printf("%lld\n",ans);
return 0;
}