题目链接:https://www.luogu.com.cn/problem/P4091
提供一种 $O(N)$ 的做法，需要用到《具体数学》里的有限微积分。
首先是化简原式:
$$\sum_{i=0}^n\sum_{j=0}^iS(i,j)\cdot2^j\cdot j!$$
$$=\sum_{i=0}^n\sum_{j=0}^nS(i,j)\cdot2^j\cdot j!$$
（当 $j>i$ 时 $S(i,j)=0$，所以内层求和的上界可以放宽到 $n$。）
$$=\sum_{i=0}^n\sum_{j=0}^n2^j\sum_{k=0}^n(-1)^{j-k}\binom{j}{k}k^i$$
（这里代入了 $S(i,j)\,j!=\sum_{k=0}^{j}(-1)^{j-k}\binom{j}{k}k^i$；当 $k>j$ 时 $\binom{j}{k}=0$，所以 $k$ 的上界也可以放宽到 $n$。）
$$=\sum_{k=0}^n(-1)^k\left(\sum_{i=0}^nk^i\right)\sum_{j=0}^n(-1)^j\binom{j}{k}2^j$$
其中 $\sum_{i=0}^nk^i$ 在 $k=0$ 时为 $1$，在 $k=1$ 时为 $n+1$，否则为 $\frac{k^{n+1}-1}{k-1}$。
然后就是有限微积分的技巧（分部求和）。
我们考虑如何求后面的式子。由于 $(-1)^j2^j=(-2)^j$，它等于
$$\sum_{j=0}^n\binom{j}{k}(-2)^j$$
令 $u=\binom{j}{k}$，$\Delta v=(-2)^j$，则 $v=\frac{(-2)^j}{-3}$，$Ev=\frac{(-2)^{j+1}}{-3}$，$\Delta u=\binom{j}{k-1}$。由分部求和公式 $\sum u\Delta v=uv-\sum Ev\Delta u$ 得
$$\sum\binom{j}{k}(-2)^j\delta j=\frac{(-2)^j\binom{j}{k}}{-3}-\frac{-2}{-3}\sum(-2)^j\binom{j}{k-1}\delta j$$
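把这个不定和式转化为定和式（若 $F_k(j)$ 表示上式左边的不定和，则 $\sum_{j=0}^{n}\binom{j}{k}(-2)^j=F_k(n+1)-F_k(0)$）之后，对 $k$ 递推即可。递推的起点是 $k=0$：$F_0(j)=\frac{(-2)^j}{-3}$。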
Code
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cstring>
using namespace std;
// From here on every `int` is really `long long`, so the product of two
// residues modulo Mod fits without overflow (this is also why `main` has to be
// declared `signed`).
#define int long long
const int MAXN = 100000;      // upper bound on n that the tables are sized for
const int Mod = 998244353;    // modulus for all arithmetic
// Tables filled by get_idk(): idk[i] = i^k and inv[i] = 1/i (mod Mod);
// vis[] and prime[1..tot] are the sieve's bookkeeping.
int idk[MAXN + 10];
int vis[MAXN + 10];
int prime[MAXN + 10];
int inv[MAXN + 10];
int tot = 0;
// Returns x^p mod Mod using binary exponentiation.
// The base may be any value, including negative ones such as the -2 passed
// by main: it is first reduced into [0, Mod), so the result is always the
// canonical residue in [0, Mod) and x * x can never overflow.
// The exponent is expected to be non-negative; p <= 0 yields 1.
int fastpow(int x, int p){
    x %= Mod;
    if (x < 0) x += Mod;
    int ans = 1;
    while (p > 0){
        if (p & 1) ans = ans * x % Mod;
        x = x * x % Mod;
        p >>= 1;
    }
    return ans;
}
// Fills, for every 1 <= i <= n,
//   idk[i] = i^k mod Mod   and   inv[i] = i^(-1) mod Mod
// with a linear (Euler) sieve. Both maps are completely multiplicative, so a
// composite m = i * p, where p is the smallest prime factor of m, is derived
// from the entries already computed for i and p; only primes call fastpow.
// Afterwards vis[i] holds the smallest prime factor of i and prime[1..tot]
// lists the primes <= n.
// Preconditions: n <= MAXN + 9 (table size) and Mod prime (the inverse is
// computed as i^(Mod-2)). Meant to be called once: tot is never reset.
void get_idk(int n, int k){
    idk[1] = inv[1] = 1;
    for (int i = 2; i <= n; ++i){
        if (!vis[i]){
            // i was never produced as i' * p, hence it is prime.
            prime[++tot] = i, vis[i] = i;
            idk[i] = fastpow(i, k);
            inv[i] = fastpow(i, Mod - 2);
        }
        for (int j = 1; j <= tot && i * prime[j] <= n; ++j){
            // Only multiply by primes not exceeding spf(i). Then prime[j] is
            // the smallest prime factor of i * prime[j], and every composite
            // is generated exactly once.
            if (prime[j] > vis[i]) break;
            vis[i * prime[j]] = prime[j];
            idk[i * prime[j]] = idk[i] * idk[prime[j]] % Mod;
            inv[i * prime[j]] = inv[i] * inv[prime[j]] % Mod;
        }
    }
}
signed main(){
//freopen ("better.in", "r", stdin);
//freopen ("better.out", "w", stdout);
int n;
scanf("%lld\n", &n);
get_idk(n, n + 1);
int p = -1 * fastpow(-2, n + 1) * inv[3] % Mod, p2 = 2 * inv[3] % Mod;
int f0 = -inv[3], fn = p, x = -1;
int ans = (fn - f0) % Mod, c = n + 1;
f0 = -p2 * f0 % Mod, fn = (p * c - p2 * fn) % Mod;
ans += x * (n + 1) % Mod * (fn - f0) % Mod;
for (register int i = 2; i <= n; ++i){
x *= -1;
c = (n - i + 2) * inv[i] % Mod * c % Mod;
fn = (p * c - p2 * fn) % Mod, f0 = -p2 * f0 % Mod;
ans += x * (idk[i] - 1) * inv[i - 1] % Mod * (fn - f0) % Mod;
ans %= Mod;
}
printf("%lld\n", ans < 0? ans + Mod : ans);
/*
int n, f, s;
scanf("%lld", &n);
get_idk(n, n + 1);
int c = 1, p = (n + 1) % 2? idk[2] : Mod - idk[2];
p = p * inv[3] % Mod;
f = (p - 1) * inv[3] % Mod;
int ans = f;
for (register int i = 1; i <= n; ++i){
c = (n - i + 2) * inv[i] % Mod * c % Mod;
f = (c * p % Mod - 2 * inv[3] % Mod * f % Mod);
int x = i % 2? Mod - 1 : 1;
if (!i) s = 1;
else if (i == 1) s = n + 1;
else s = idk[i] * inv[i - 1] % Mod;
int tmp = x * s % Mod * f % Mod;
ans = (ans + tmp) % Mod;
}
cout << ans << endl;
*/
return 0;
}