题面:bzoj4555
题意:求 $\sum_{i=0}^{n}\sum_{j=0}^{i} S(i,j)\cdot j!\cdot 2^j \bmod 998244353$,其中 $S(i,j)$ 是第二类斯特林数。
题解:第二类斯特林数 $S(i,j)$ 表示将 $i$ 个互不相同的球放入 $j$ 个相同的盒子、且每个盒子都非空的方案数。
考虑容斥:先把 $j$ 个盒子看作互不相同,枚举被钦定为空的盒子数 $k$,则无空盒子的方案数 $=$ 任意放置的方案数 $-$ 至少一个空盒子的方案数 $+$ 至少两个空盒子的方案数 $-\cdots$。
最后因为 $j$ 个盒子实际上完全相同,再除以 $j!$,得到:
$$S(i,j)=\frac{1}{j!}\sum_{k=0}^{j}(-1)^k\binom{j}{k}(j-k)^i$$
$$
\begin{aligned}
ans &= \sum_{i=0}^{n}\sum_{j=0}^{n} S(i,j)\cdot j!\cdot 2^j \\
&= \sum_{i=0}^{n}\sum_{j=0}^{n} j!\cdot 2^j\cdot\frac{1}{j!}\sum_{k=0}^{j}(-1)^k\binom{j}{k}(j-k)^i \\
&= \sum_{j=0}^{n} 2^j\sum_{k=0}^{j}(-1)^k\binom{j}{k}\sum_{i=0}^{n}(j-k)^i \\
&= \sum_{j=0}^{n} 2^j\sum_{k=0}^{j}(-1)^k\frac{j!}{k!\,(j-k)!}\cdot\frac{(j-k)^{n+1}-1}{j-k-1} \\
&= \sum_{j=0}^{n} 2^j\cdot j!\cdot\sum_{k=0}^{j}\frac{(-1)^k}{k!}\cdot\frac{(j-k)^{n+1}-1}{(j-k)!\,(j-k-1)}
\end{aligned}
$$
第一步把 $j$ 的上界从 $i$ 放宽到 $n$,因为 $j>i$ 时 $S(i,j)=0$。等比数列求和公式在 $j-k\in\{0,1\}$ 时不适用,需单独处理:$j-k=0$ 时 $\sum_{i=0}^{n}0^i=1$(约定 $0^0=1$),$j-k=1$ 时 $\sum_{i=0}^{n}1^i=n+1$。
令 $f(k)=\frac{(-1)^k}{k!}$,$h(t)=\frac{t^{n+1}-1}{t!\,(t-1)}$(并取 $h(0)=1$,$h(1)=n+1$),则内层和就是卷积 $(f*h)(j)$。用 NTT 在 $O(n\log n)$ 内预处理出该卷积,最后 $O(n)$ 累加答案。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll ;

// maxn must hold the NTT length, which is the smallest power of two strictly
// greater than 2n. That is 2^19 = 524288 once n >= 2^17, so 2^19 + 10 covers
// every n < 2^18.
// mod = 998244353 = 119 * 2^23 + 1 is NTT-friendly, and g = 3 is a primitive
// root modulo it.
const int maxn = (1 << 19) + 10, mod = 998244353, g = 3 ;

int n ;     // problem size read from input
int ginv ;  // modular inverse of g, assigned in main()

ll fac[maxn] ;  // fac[i] = i! mod `mod`
ll inv[maxn] ;  // inv[i] = (i!)^{-1} mod `mod` (inverse factorials)

ll a[maxn], b[maxn] ;  // the two convolution operands / NTT work arrays
// Returns x^y mod `mod` using binary exponentiation.
//   x : any base. It is first reduced into [0, mod), so negative bases and
//       bases >= mod are handled, and x * x can never overflow a long long.
//   y : exponent, which must be non-negative.
ll power (ll x, int y) {
    x %= mod ;
    if (x < 0) x += mod ;   // C++ '%' keeps the sign of the dividend
    ll res = 1 ;
    while (y) {
        if (y & 1) res = res * x % mod ;
        x = x * x % mod ;
        y >>= 1 ;
    }
    return res ;
}
// In-place iterative radix-2 number-theoretic transform modulo `mod`.
//   a : array of length n. Entries are expected in [0, mod) and stay in that
//       range.
//   n : transform length. It must be a power of two and at most 2^23, since
//       mod - 1 = 119 * 2^23.
//   f : 0 selects the forward transform, nonzero the inverse transform.
// The inverse transform is NOT scaled by 1/n. The caller must multiply every
// entry by n^{-1} afterwards.
void ntt (ll a[], int n, int f) {
    // Bit-reversal permutation: j is maintained as the bit-reverse of i.
    for (int i = 0, j = 0; i < n; i ++) {
        if (i > j) swap (a[i], a[j]) ;
        for (int t = n >> 1; (j ^= t) < t; t >>= 1) ;
    }
    for (int len = 2; len <= n; len <<= 1) {
        int half = len >> 1 ;
        // wn is a primitive len-th root of unity, or its inverse for f != 0.
        // Because g^(mod-1) = 1, g^(-(mod-1)/len) = g^((mod-1) - (mod-1)/len).
        // This avoids relying on a separately initialised inverse of g.
        ll wn = power (g, f ? (mod - 1) - (mod - 1) / len : (mod - 1) / len) ;
        for (int j = 0; j < n; j += len) {
            ll w = 1 ;
            for (int k = 0; k < half; k ++, w = w * wn % mod) {
                ll A = a[j + k], B = w * a[j + k + half] % mod ;
                a[j + k] = (A + B) % mod ;
                a[j + k + half] = (A - B + mod) % mod ;
            }
        }
    }
}
int main() {
cin >> n ;
ginv = power (g, mod - 2) ;
fac[0] = inv[0] = 1 ;
for (int i = 1; i <= n; i ++) fac[i] = fac[i - 1] * i % mod ;
inv[n] = power (fac[n], mod - 2) ;
for (int i = n - 1; i >= 1; i --) inv[i] = inv[i + 1] * (i + 1) % mod ;
for (int i = 0; i <= n; i ++) a[i] = 1ll * (i & 1 ? -1 : 1) * inv[i] % mod ;
for (int i = 2; i <= n; i ++) b[i] = (power (i, n + 1) - 1 + mod) % mod * inv[i] % mod * power (i - 1, mod - 2) % mod ;
b[0] = 1; b[1] = n + 1 ;
int m = 2 * n ;
for (n = 1; n <= m; n <<= 1) ;
ntt (a, n, 0); ntt (b, n, 0) ;
for (int i = 0; i < n; i ++) a[i] = a[i] * b[i] % mod ;
ntt (a, n, 1) ;
ll invn = power (n, mod - 2) ;
for (int i = 0; i < n; i ++) a[i] = a[i] * invn % mod ;
ll p = 1, ans = 0 ;
for (int i = 0; i <= n; i ++, p = (p + p) % mod) ans = (ans + p * fac[i] % mod * a[i] % mod) % mod ;
printf("%lld\n", ans) ;
return 0 ;
}