Description
在一个
s
s
s个点的图中,存在
s
−
n
s−n
s−n条边,使图中形成了
n
n
n个连通块,第
i
i
i个连通块中有
a
i
a_i
ai个点。
现在我们需要再连接
n
−
1
n−1
n−1条边,使该图变成一棵树。对一种连边方案,设原图中第
i
i
i个连通块连出了
d
i
d_i
di条边,那么这棵树
T
T
T的价值为:
v
a
l
(
T
)
=
(
∏
i
=
1
n
d
i
m
)
(
∑
i
=
1
n
d
i
m
)
val(T)=(\prod_{i=1}^nd_i^m)(\sum_{i=1}^nd_i^m)
val(T)=(i=1∏ndim)(i=1∑ndim)你的任务是求出所有可能的生成树的价值之和,对
998244353
998244353
998244353取模。
Sample Input
3 1
2 3 4
Sample Output
1728
首先一个
a
i
a_i
ai的贡献就是
a
i
d
i
a_i^{d_i}
aidi。
先考虑这个式子的形式一个
s
u
m
sum
sum乘上一个乘积太难搞了。
我们把它变成有一个
i
i
i可以多乘上一个
d
i
m
d_i^m
dim,那么你就可以设:
f
[
i
]
[
j
]
f[i][j]
f[i][j]为
i
i
i个点已经填了
p
u
f
e
r
pufer
pufer序中的
j
j
j个点的方案,
g
[
i
]
[
j
]
g[i][j]
g[i][j]为
i
i
i个点已经填了
p
u
f
e
r
pufer
pufer序中的
j
j
j个点,有了一个特殊点方案,
每次枚举当前这个点加进来多少点转移即可。
这个
d
i
d_i
di枚举的范围太大了,不是很好搞,我们考虑一些组合意义来优化。
d
i
m
d_i^m
dim相当于有
d
i
d_i
di个格子,你可以染
m
m
m次颜色,每次可以选任意一个格子涂颜色。但是你一个点的度数实际上是
d
i
+
1
d_i + 1
di+1,所以你每一个点都要在后面再额外预留一个格子。
你发现这样有颜色的格子是不超过
m
m
m个的,状态改变为:
f
[
i
]
[
j
]
f[i][j]
f[i][j]表示
i
i
i个颜色确定前半部分的
j
j
j个格子,你枚举前半部分填了几个有颜色的格子,可得转移:
f
[
i
]
[
j
]
=
∑
k
=
0
m
f
[
i
−
1
]
[
j
−
k
]
∗
C
n
−
2
−
j
+
k
k
∗
a
i
k
+
1
(
S
m
k
∗
k
!
+
S
m
k
+
1
∗
(
k
+
1
)
!
)
f[i][j]=\sum_{k=0}^mf[i-1][j-k]*C_{n-2-j+k}^k*a_i^{k+1}(S_m^k*k!+S_m^{k+1}*(k+1)!)
f[i][j]=k=0∑mf[i−1][j−k]∗Cn−2−j+kk∗aik+1(Smk∗k!+Smk+1∗(k+1)!)
g
[
i
]
[
j
]
g[i][j]
g[i][j]表示
i
i
i个格子确定了
j
j
j是有颜色的,已经确定了一个特殊点,可得转移:
g
[
i
]
[
j
]
=
∑
k
=
0
m
g
[
i
−
1
]
[
j
−
k
]
∗
C
n
−
2
−
j
+
k
k
∗
a
i
k
+
1
(
S
m
k
∗
k
!
+
S
m
k
+
1
∗
(
k
+
1
)
!
)
g[i][j]=\sum_{k=0}^mg[i-1][j-k]*C_{n-2-j+k}^k*a_i^{k+1}(S_m^k*k!+S_m^{k+1}*(k+1)!)
g[i][j]=k=0∑mg[i−1][j−k]∗Cn−2−j+kk∗aik+1(Smk∗k!+Smk+1∗(k+1)!)
g
[
i
]
[
j
]
+
=
∑
k
=
0
2
m
f
[
i
−
1
]
[
j
−
k
]
∗
C
n
−
2
−
j
+
k
k
∗
a
i
k
+
1
(
S
2
m
k
∗
k
!
+
S
2
m
k
+
1
∗
(
k
+
1
)
!
)
g[i][j]+=\sum_{k=0}^{2m}f[i-1][j-k]*C_{n-2-j+k}^k*a_i^{k+1}(S_{2m}^k*k!+S_{2m}^{k+1}*(k+1)!)
g[i][j]+=k=0∑2mf[i−1][j−k]∗Cn−2−j+kk∗aik+1(S2mk∗k!+S2mk+1∗(k+1)!)
那个斯特林数就是分别考虑后半部分填或不填。
最后再考虑把那些每个
a
i
a_i
ai占了但不涂颜色的位置都给乘回来,可得答案:
a
n
s
=
∑
i
=
0
n
−
2
g
[
n
]
[
i
]
∗
(
∑
j
=
1
n
a
j
)
n
−
2
−
i
ans=\sum_{i=0}^{n-2}g[n][i]*(\sum_{j=1}^na_j)^{n-2-i}
ans=i=0∑n−2g[n][i]∗(j=1∑naj)n−2−i
这个
d
p
dp
dp式直接用分治
N
T
T
NTT
NTT优化即可做到
O
(
n
m
l
o
g
2
n
)
O(nmlog^2n)
O(nmlog2n)
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
int _max(int x, int y) {return x > y ? x : y;}
int _min(int x, int y) {return x < y ? x : y;}
const int mod = 998244353;
const int N = 30001, M = 31;
int read() {
int s = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
return s * f;
}
void put(int x) {
if(x >= 10) put(x / 10);
putchar(x % 10 + '0');
}
int jc[N], inv[N];
int n, m, R[N << 2], S[M * 2][M * 2], d1[N][M * 2], d2[N][M * 2];
int a[N], f[20][N << 2], g[20][N << 2];
int pow_mod(int a, int k) {
int ans = 1;
while(k) {
if(k & 1) ans = (LL)ans * a % mod;
a = (LL)a * a % mod, k /= 2;
} return ans;
}
int add(int x, int y) {
x += y;
return x >= mod ? x - mod : x;
}
int dec(int x, int y) {
x -= y;
return x < 0 ? x + mod : x;
}
void NTT(int y[], int len, int on) {
for(int i = 0; i < len; i++) if(i < R[i]) swap(y[i], y[R[i]]);
for(int i = 1; i < len; i <<= 1) {
int wn = pow_mod(3, (mod - 1) / (i << 1));
if(on == -1) wn = pow_mod(wn, mod - 2);
for(int j = 0; j < len; j += i << 1) {
int w = 1;
for(int k = 0; k < i; k++, w = (LL)w * wn % mod) {
int u = y[j + k], v = (LL)y[j + k + i] * w % mod;
y[j + k] = add(u, v), y[j + k + i] = dec(u, v);
}
}
} if(on == -1) {
int hh = pow_mod(len, mod - 2);
for(int i = 0; i < len; i++) y[i] = (LL)y[i] * hh % mod;
}
}
void pre() {
jc[0] = inv[0] = 1; for(int i = 1; i <= n; i++) jc[i] = (LL)jc[i - 1] * i % mod;
inv[n] = pow_mod(jc[n], mod - 2); for(int i = n - 1; i >= 1; i--) inv[i] = (LL)inv[i + 1] * (i + 1) % mod;
S[0][0] = 1;
for(int i = 1; i <= 2 * m; i++) {
for(int j = 1; j <= _min(n - 1, 2 * m); j++) S[i][j] = add(S[i - 1][j - 1], (LL)S[i - 1][j] * j % mod);
} for(int i = 1; i <= n; i++) {
d1[i][0] = d2[i][0] = 1;
for(int j = 1; j <= _min(n - 2, m); j++) d1[i][j] = (LL)d1[i][j - 1] * a[i] % mod;
for(int j = 1; j <= _min(n - 2, 2 * m); j++) d2[i][j] = (LL)d2[i][j - 1] * a[i] % mod;
} for(int i = 1; i <= n; i++) {
for(int j = 0; j <= _min(n - 2, m); j++) d1[i][j] = (LL)d1[i][j] * inv[j] % mod * add((LL)S[m][j] * jc[j] % mod, (LL)S[m][j + 1] * jc[j + 1] % mod) % mod;
for(int j = 0; j <= _min(n - 2, 2 * m); j++) d2[i][j] = (LL)d2[i][j] * inv[j] % mod * add((LL)S[2 * m][j] * jc[j] % mod, (LL)S[2 * m][j + 1] * jc[j + 1] % mod) % mod;
}
}
void solve(int l, int r, int d) {
if(l == r) {
for(int i = 0; i <= _min(n, 2 * m); i++) g[d][i] = d2[l][i], f[d][i] = d1[l][i];
return ;
} int mid = (l + r) / 2;
int L1 = _min((mid - l + 1) * m + m, n - 2), L2 = _min((r - mid) * m + m, n - 2), L;
for(L = 1; L <= (L1 + L2); L <<= 1);
solve(l, mid, d + 1);
for(int i = 0; i <= L1; i++) f[d][i] = f[d + 1][i];
for(int i = 0; i <= L1; i++) g[d][i] = g[d + 1][i];
solve(mid + 1, r, d + 1);
for(int i = 0; i < L; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) * (L >> 1));
for(int i = L1 + 1; i < L; i++) f[d][i] = g[d][i] = 0;
for(int i = L2 + 1; i < L; i++) f[d + 1][i] = g[d + 1][i] = 0;
NTT(f[d], L, 1), NTT(g[d], L, 1), NTT(f[d + 1], L, 1), NTT(g[d + 1], L, 1);
for(int i = 0; i < L; i++) g[d][i] = add((LL)f[d + 1][i] * g[d][i] % mod, (LL)f[d][i] * g[d + 1][i] % mod), f[d][i] = (LL)f[d][i] * f[d + 1][i] % mod;
NTT(f[d], L, -1), NTT(g[d], L, -1);
int uu = (r - l + 1) * m + m;
for(int i = _min(uu, n - 2) + 1; i < L; i++) g[d][i] = 0;
for(int i = _min(uu, n - 2) + 1; i < L; i++) f[d][i] = 0;
}
int main() {
n = read(), m = read(); int sum = 0;
for(int i = 1; i <= n; i++) a[i] = read(), sum = add(sum, a[i]);
pre(), solve(1, n, 0); int ans = 0, o = 1;
for(int i = n - 2; i >= 0; i--) {
ans = add(ans, (LL)g[0][i] * inv[n - 2 - i] % mod * o % mod);
o = (LL)o * sum % mod;
} for(int i = 1; i <= n; i++) ans = (LL)ans * a[i] % mod;
put((LL)ans * jc[n - 2] % mod);
return 0;
}