Address
Solution
-
$$ans=\sum_{i=0}^{n}\sum_{j=0}^{i}S(i,j)\cdot 2^j\cdot j!$$
-
因为当 $j>i$ 时 $S(i,j)=0$,所以可以把内层求和的上界从 $i$ 扩展到 $n$:
$$ans=\sum_{i=0}^{n}\sum_{j=0}^{n}S(i,j)\cdot 2^j\cdot j!$$
众所周知,第二类斯特林数有如下通项公式(由容斥原理得到):
$$S(i,j)=\frac{1}{j!}\sum_{k=0}^{j}(-1)^k\cdot(j-k)^i\cdot C_j^k$$
因此(代入后 $\frac{1}{j!}$ 与 $j!$ 约去):
$$ans=\sum_{i=0}^{n}\sum_{j=0}^{n}\sum_{k=0}^{j}(-1)^k\cdot(j-k)^i\cdot C_j^k\cdot 2^j$$
发现 $2^j$ 只与变量 $j$ 有关,所以交换求和顺序,把它提到前面:
$$ans=\sum_{j=0}^{n}2^j\sum_{i=0}^{n}\sum_{k=0}^{j}(-1)^k\cdot(j-k)^i\cdot C_j^k$$
然后把 $C_j^k=\frac{j!}{k!\,(j-k)!}$ 拆成阶乘形式,再整理得:
$$ans=\sum_{j=0}^{n}2^j\cdot j!\sum_{k=0}^{j}\frac{(-1)^k}{k!}\cdot\frac{\sum_{i=0}^{n}(j-k)^i}{(j-k)!}$$
于是令 $f(i)=\frac{(-1)^i}{i!}$,$g(j)=\frac{\sum_{i=0}^{n}j^i}{j!}$,则内层的 $\sum_{k=0}^{j}f(k)\,g(j-k)$ 恰好是 $f$ 与 $g$ 卷积的第 $j$ 项。
-
显然当 $j\neq 1$ 时,$g(j)$ 可以用等比数列求和公式变成:
$$g(j)=\frac{j^{n+1}-1}{j!\,(j-1)}$$
而 $j=1$ 时公比为 $1$,直接有 $g(1)=n+1$;另外约定 $0^0=1$,故 $g(0)=1$。
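那么用 $NTT$ 把 $f$ 和 $g$ 卷积起来得到 $h=f*g$,答案即为 $ans=\sum_{j=0}^{n}2^j\cdot j!\cdot h(j)$。注意 $h$ 的次数可达 $2n$,所以 $NTT$ 的长度必须严格大于 $2n$,否则循环卷积会把高次项回绕到低位上。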
Code
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
// Capacity of every static array below; the NTT length `lim` must never exceed it.
const int e = 1e6 + 5;
// Prime modulus for all arithmetic; fft() uses 3 as its root of unity generator.
const int mod = 998244353;

int n;        // problem input
int ans;      // accumulated answer, reduced mod `mod`
int lim = 1;  // NTT length (a power of two), chosen in main()
int rev[e];   // bit-reversal permutation for length `lim`
int a[e];     // coefficients of f; also receives the product f*g
int b[e];     // coefficients of g
// Not referenced by any code shown in this file.
int fa[e], g[e], cc[e], dd[e];
// Modular exponentiation: returns x^y reduced modulo `mod`, using
// square-and-multiply over the binary digits of the exponent.
// y must be non-negative.
inline int ksm(int x, int y)
{
    long long result = 1;
    long long base = x;
    for (; y != 0; y >>= 1)
    {
        // Multiply in the current power of x when this bit of y is set.
        if ((y & 1) != 0)
            result = result * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(result);
}
// In-place iterative number-theoretic transform over Z_mod (mod = 998244353).
//   n  : transform length; must be a power of two, and the global rev[0..n-1]
//        must already hold the bit-reversal permutation for this length.
//   a  : coefficient array of at least n entries, transformed in place.
//   op : 1 for the forward transform (generator 3); any other value selects the
//        inverse transform (generator 3^{-1}).
// The inverse transform is NOT scaled: the caller must multiply every entry by
// n^{-1} afterwards. Requires (mod - 1) to be divisible by n.
inline void fft(int n, int *a, int op)
{
    // Derive 3^{-1} from `mod` instead of hard-coding 998244354 / 3, which is
    // only the inverse of 3 for this particular prime.
    const int root = (op == 1 ? 3 : ksm(3, mod - 2));
    for (int i = 0; i < n; i++)
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int half = 1; half < n; half <<= 1)
    {
        // Primitive (2*half)-th root of unity for this butterfly layer.
        const int step = ksm(root, (mod - 1) / (half << 1));
        for (int start = 0; start < n; start += (half << 1))
        {
            int w = 1;
            for (int j = 0; j < half; j++)
            {
                // u/v instead of b/c so nothing shadows the global array `b`.
                const int u = a[start + j];
                const int v = 1ll * w * a[start + j + half] % mod;
                a[start + j] = (u + v) % mod;
                a[start + j + half] = (u - v + mod) % mod;
                w = 1ll * w * step % mod;
            }
        }
    }
}
// Reads n and prints
//   ans = sum_{i=0}^{n} sum_{j=0}^{i} S(i,j) * 2^j * j!   (mod 998244353),
// evaluated as sum_{j=0}^{n} 2^j * j! * (f*g)(j), where
//   f(k) = (-1)^k / k!   and   g(j) = (sum_{t=0}^{n} j^t) / j!,
// and the convolution f*g is computed with fft().
// Assumes the NTT length fits the static arrays (lim <= e, i.e. n < 2^18).
int main()
{
    cin >> n;

    // f and g both have degree n, so f*g has degree 2n. The cyclic NTT length
    // must be strictly greater than 2n; with lim == 2n the x^{2n} term would
    // wrap onto index 0 (e.g. n = 1 would print 1 instead of 3).
    int k = 0;
    while (lim <= n * 2)
    {
        lim <<= 1;
        k++;
    }
    for (int i = 1; i < lim; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));

    // Fill f into a[] and g into b[] for degrees 0..n.
    int fac = 1;
    for (int i = 0; i <= n; i++)
    {
        if (i != 0) fac = 1ll * fac * i % mod;
        const int inv_fac = ksm(fac, mod - 2);  // 1 / i!

        // f(i) = (-1)^i / i!
        a[i] = (i & 1) ? mod - inv_fac : inv_fac;

        // g(i) = (sum_{t=0}^{n} i^t) / i!
        if (i == 0)
            b[i] = 1;                    // only the 0^0 = 1 term survives
        else if (i == 1)
            b[i] = (n + 1) % mod;        // ratio 1: the geometric formula does not apply
        else
            b[i] = 1ll * (ksm(i, n + 1) + mod - 1) % mod
                   * ksm(i - 1, mod - 2) % mod * inv_fac % mod;
    }

    // Convolve f and g.
    fft(lim, a, 1);
    fft(lim, b, 1);
    for (int i = 0; i < lim; i++) a[i] = 1ll * a[i] * b[i] % mod;
    fft(lim, a, -1);
    // fft() leaves the inverse transform unscaled; divide by lim once.
    const int inv_lim = ksm(lim, mod - 2);
    for (int i = 0; i < lim; i++) a[i] = 1ll * a[i] * inv_lim % mod;

    // ans = sum_{i=0}^{n} 2^i * i! * (f*g)(i)
    int pow2 = 1;
    fac = 1;
    for (int i = 0; i <= n; i++)
    {
        if (i != 0) fac = 1ll * fac * i % mod;
        ans = (ans + 1ll * a[i] * fac % mod * pow2) % mod;
        pow2 = 2ll * pow2 % mod;
    }
    cout << ans << endl;
    return 0;
}