【题目链接】
https://www.lydsy.com/JudgeOnline/problem.php?id=5093
【题解】
首先每个点都是独立的,可以求出一个点的贡献再把它乘以
n
n
,枚举这个点连了多少条边,可以列出式子:
考虑第二类斯特林数的一个恒等式:
Xn=∑Xi=0(Xi)i!Sn,i
X
n
=
∑
i
=
0
X
(
i
X
)
i
!
S
n
,
i
将它代入:
ans=n∗2(n−12)∑n−1i=0(n−1i)∑ij=0(ij)j!Sk,j
a
n
s
=
n
∗
2
(
2
n
−
1
)
∑
i
=
0
n
−
1
(
i
n
−
1
)
∑
j
=
0
i
(
j
i
)
j
!
S
k
,
j
现在考虑求:
∑n−1i=0(n−1i)∑ij=0(ij)j!Sk,j
∑
i
=
0
n
−
1
(
i
n
−
1
)
∑
j
=
0
i
(
j
i
)
j
!
S
k
,
j
先改变求和顺序:
∑n−1j=0j!Sk,j∑n−1i=j(n−1i)(ij)
∑
j
=
0
n
−
1
j
!
S
k
,
j
∑
i
=
j
n
−
1
(
i
n
−
1
)
(
j
i
)
考虑后面的和式的意义:先从
n−1
n
−
1
个数中取出
j
j
个数,再从个数中取出
i
i
个数。
因此可以化为:从个数中选出
j
j
个数,其他的数是否选取随意。于是可以转换为:
由于
Sk,n(n>k)=0
S
k
,
n
(
n
>
k
)
=
0
所以只要求第
k
k
行的前个斯特林数即可。
对于恒等式做二项式反演:
i!Sk,i=∑ij=0(−1)i−j(ij)ik
i
!
S
k
,
i
=
∑
j
=
0
i
(
−
1
)
i
−
j
(
j
i
)
i
k
Sk,i=∑ij=0(−1)i−j(i−j)!∗iki!
S
k
,
i
=
∑
j
=
0
i
(
−
1
)
i
−
j
(
i
−
j
)
!
∗
i
k
i
!
NTT即可。
时间复杂度
O(KlogK)
O
(
K
l
o
g
K
)
【代码】
/* - - - - - - - - - - - - - - -
User : VanishD
problem : [bzoj5093]
Points : NTT + stirling
- - - - - - - - - - - - - - - */
# include <bits/stdc++.h>
# define ll long long
# define N 1000100
using namespace std;
const int inf = 0x3f3f3f3f, INF = 0x7fffffff, P = 998244353, G = 3;
const ll infll = 0x3f3f3f3f3f3f3f3fll, INFll = 0x7fffffffffffffffll;
int read(){
int tmp = 0, fh = 1; char ch = getchar();
while (ch < '0' || ch > '9'){ if (ch == '-') fh = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9'){ tmp = tmp * 10 + ch - '0'; ch = getchar(); }
return tmp * fh;
}
int power(int x, ll y){
int i = x; x = 1;
while (y > 0){
if (y % 2 == 1) x = 1ll * i * x % P;
i = 1ll * i * i % P;
y /= 2;
}
return x;
}
void NTT(int *a, int l, int tag){
for (int i = 0, j = 0; i < l; i++){
if (i < j) swap(a[i], a[j]);
for (int k = (l >> 1); (j ^= k) < k; k >>= 1);
}
for (int i = 1; i < l; i *= 2){
int wn = power(G, (P - 1) / (i * 2));
if (tag == -1) wn = power(wn, P - 2);
for (int j = 0; j < l; j += i * 2)
for (int k = 0, w = 1; k < i; k++, w = 1ll * w * wn % P){
int x = a[k + j], y = 1ll * w * a[k + i + j] % P;
a[k + j] = (x + y) % P; a[k + i + j] = (x - y) % P;
}
}
if (tag == -1){
int inv = power(l, P - 2);
for (int i = 0; i < l; i++) a[i] = 1ll * a[i] * inv % P;
}
}
int powk[N], mul[N], a[N], b[N], C[N], n, k, l, S[N];
void getS(int n){
powk[0] = 1, mul[0] = 1;
for (int i = 1; i <= n; i++) mul[i] = 1ll * mul[i - 1] * i % P;
for (int i = 0; i <= n; i++){
a[i] = power(-1, i) * power(mul[i], P - 2);
b[i] = 1ll * power(i, n) * power(mul[i], P - 2) % P;
}
l = 1;
while (l <= n * 2) l <<= 1;
NTT(a, l, 1), NTT(b, l, 1);
for (int i = 0; i < l; i++) a[i] = 1ll * a[i] * b[i] % P;
NTT(a, l, -1);
for (int i = 0; i <= n; i++) S[i] = (a[i] + P) % P;
}
int main(){
// freopen(".in", "r", stdin);
// freopen(".out", "w", stdout);
n = read(), k = read();
getS(k);
int lim = min(n - 1, k), ans = 0;
C[0] = 1; for (int i = 1; i <= k; i++) C[i] = 1ll * C[i - 1] * (n - i) % P * power(i, P - 2) % P;
for (int i = 0; i <= lim; i++)
ans = (ans + 1ll * S[i] * mul[i] % P * C[i] % P * power(2, n - 1 - i)) % P;
ans = (ans + P) % P;
ll tmp = 1ll * (n - 1)* (n - 2) / 2;
ans = 1ll * ans * n % P *power(2, tmp) % P;
printf("%d\n", ans);
return 0;
}