Description
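给定 $n$,求
$$f(n)=\sum_{i=0}^{n}\sum_{j=0}^{i}S(i,j)\times 2^j\times j! \pmod{998244353}$$
其中 $S(i,j)$ 为第二类斯特林数。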
Solution
推式子
$$f(n)=\sum_{i=0}^{n}\sum_{j=0}^{i}S(i,j)\times 2^j\times j!$$
交换求和顺序(当 $j>i$ 时 $S(i,j)=0$,所以内层的 $i$ 可以直接从 $0$ 取到 $n$):
$$f(n)=\sum_{j=0}^{n}2^j\times j!\sum_{i=0}^{n}S(i,j)$$
代入第二类斯特林数的通项公式 $S(i,j)=\sum_{k=0}^{j}\frac{(-1)^k\,(j-k)^i}{k!\,(j-k)!}$:
$$f(n)=\sum_{j=0}^{n}2^j\times j!\sum_{i=0}^{n}\sum_{k=0}^{j}\frac{(-1)^k\,(j-k)^i}{k!\,(j-k)!}$$
再交换 $i$ 与 $k$ 的求和顺序:
$$f(n)=\sum_{j=0}^{n}2^j\times j!\sum_{k=0}^{j}\frac{(-1)^k}{k!}\cdot\frac{\sum_{i=0}^{n}(j-k)^i}{(j-k)!}$$
其中 $\sum_{i=0}^{n}(j-k)^i$ 是等比数列求和。记 $x=j-k$,可以用公式 $\frac{x^{n+1}-1}{x-1}$ 直接求出;$x=1$ 时结果为 $n+1$,$x=0$ 时结果为 $1$(只有 $0^0$ 一项)。剩下的内层和是关于 $k$ 与 $j-k$ 的卷积,用 NTT 计算即可。
#include <bits/stdc++.h>
using namespace std;
using lint = long long;           // 64-bit type for intermediate modular products

constexpr int mod = 998244353;    // NTT-friendly prime: 119 * 2^23 + 1
constexpr int G = 3;              // primitive root modulo `mod`
constexpr int Phi = mod - 1;      // order of the multiplicative group, used to derive roots of unity
constexpr int maxn = 300005;      // capacity of every array below (must cover the padded NTT length)

// Problem input, answer, and per-index tables (factorials, geometric sums, inverse factorials).
int n;
int ans;
int fac[maxn];
int rat[maxn];
int ifac[maxn];

// NTT buffers: operands/result (A, B), bit-reversal table (R),
// doubled input size (m) and log2 of the transform length (L).
int A[maxn];
int B[maxn];
int R[maxn];
int m;
int L;
int Pow(int x, int k)
{
int res = 1;
while (k) {
if (k & 1) res = (lint)res * x % mod;
x = (lint)x * x % mod; k >>= 1;
}
return res;
}
// In-place iterative number-theoretic transform of a[0 .. n-1], where the
// length is the global n. R must already hold the bit-reversal permutation
// for that length. f == -1 selects the inverse roots of unity; any other
// value runs the forward transform. The inverse result is NOT divided by n
// here; the caller performs that scaling.
void NTT(int *a, int f)
{
    // Reorder into bit-reversed index order so the butterflies can run in place.
    for (int idx = 0; idx < n; ++idx) {
        if (idx < R[idx])
            std::swap(a[idx], a[R[idx]]);
    }

    for (int half = 1; half < n; half <<= 1) {
        const int span = half << 1;
        // Principal span-th root of unity (or its inverse for the inverse transform).
        int root = Pow(G, Phi / span);
        if (f == -1)
            root = Pow(root, mod - 2);

        for (int base = 0; base < n; base += span) {
            int w = 1;
            for (int off = 0; off < half; ++off) {
                int &lo = a[base + off];
                int &hi = a[base + off + half];
                const int t = (lint)hi * w % mod;
                const int diff = lo - t;
                const int sum = lo + t;
                hi = diff < 0 ? diff + mod : diff;
                lo = sum >= mod ? sum - mod : sum;
                w = (lint)w * root % mod;
            }
        }
    }
}
// Multiplies the polynomials stored in A[0..n] and B[0..n] and leaves the
// product (degree <= 2n) in A; B is left holding its forward transform.
// Entries of A and B above index n are expected to be zero on entry.
// The global n is temporarily replaced by the transform length (the
// two-argument NTT reads it) and restored before returning.
// The transform length must not exceed maxn (with maxn = 300005 this holds
// for n <= 131071).
void NTT()
{
    m = n << 1;                          // degree of the product
    L = 0;                               // reset so the length is correct on every call
    for (n = 1; n <= m; n <<= 1) ++L;    // smallest power of two strictly greater than 2n

    // Bit-reversal table. The high-bit term is only formed for odd i, and odd
    // indices exist only when L >= 1, so input n == 0 (L == 0) no longer
    // evaluates a shift by -1, which is undefined behaviour.
    for (int i = 0; i < n; ++i)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) ? (1 << (L - 1)) : 0);

    NTT(A, 1); NTT(B, 1);
    for (int i = 0; i < n; ++i) A[i] = (lint)A[i] * B[i] % mod;
    NTT(A, -1);

    // The inverse transform leaves every coefficient scaled by the length.
    const int inv = Pow(n, mod - 2);
    for (int i = 0; i < n; ++i) A[i] = (lint)A[i] * inv % mod;

    n = m >> 1;                          // restore the caller's n
}
int main()
{
scanf("%d", &n);
fac[0] = 1; rat[0] = 1;
for (int i = 1; i <= n; ++i) {
fac[i] = (lint)fac[i - 1] * i % mod;
rat[i] = (lint)(Pow(i, n + 1) - 1) * Pow(i - 1, mod - 2) % mod;
if (rat[i] < 0) rat[i] += mod;
}
rat[1] = n + 1;
ifac[n] = Pow(fac[n], mod - 2);
for (int i = n - 1; i >= 0; --i)
ifac[i] = (lint)ifac[i + 1] * (i + 1) % mod;
for (int i = 0; i <= n; ++i) {
A[i] = ifac[i];
if (i & 1) A[i] = mod - A[i];
B[i] = (lint)rat[i] * ifac[i] % mod;
}
NTT();
for (int i = 0, sum = 1; i <= n; ++i) {
ans += (lint)sum * fac[i] % mod * A[i] % mod;
if (ans >= mod) ans -= mod;
sum = (lint)sum * 2 % mod;
}
printf("%d\n", ans);
return 0;
}