题目描述
在一个 sss 个点的图中,存在 s−ns-ns−n 条边,使图中形成了 nnn 个连通块,第 iii 个连通块中有 aia_iai 个点。
现在我们需要再连接 n−1n-1n−1 条边,使该图变成一棵树。对一种连边方案,设原图中第 iii 个连通块连出了 did_idi 条边,那么这棵树 TTT 的价值为:
你的任务是求出所有可能的生成树的价值之和,对 998244353998244353998244353 取模。
输入
输入的第一行包含两个整数 n,mn,mn,m,意义见题目描述。
接下来一行有 nnn 个整数,第 iii 个整数表示 aia_iai (1≤ai<998244353)(1\le a_i< 998244353)(1≤ai<998244353)。
- 你可以由 aia_iai 计算出图的总点数 sss,所以在输入中不再给出 sss 的值。
输出
输出包含一行一个整数,表示答案。
样例输入
3 1
2 3 4
样例输出
1728
提示
本题共有 202020 个测试点,每个测试点 555 分。
-
20%20\%20% 的数据中,n≤500n\le500n≤500。
-
另外 20%20\%20% 的数据中,n≤3000n \le 3000n≤3000。
-
另外 10%10\%10% 的数据中,n≤10010,m=1n \le 10010, m = 1n≤10010,m=1。
-
另外 10%10\%10%的数据中,n≤10015,m=2n \le 10015,m = 2n≤10015,m=2。
-
另外 20%20\%20% 的数据中,所有 aia_iai 相等。
-
100%100\%100% 的数据中,n≤3×104,m≤30n \le 3\times 10^4,m \le 30n≤3×104,m≤30。
其中,每一个部分分的测试点均有一定梯度。
#include <bits/stdc++.h>
template <class T>
inline void read(T &x)
{
static char ch;
while (!isdigit(ch = getchar()));
x = ch - '0';
while (isdigit(ch = getchar()))
x = x * 10 + ch - '0';
}
const int mod = 998244353;
inline int qpow(int x, int y)
{
int res = 1;
for (; y; y >>= 1, x = 1LL * x * x % mod)
if (y & 1)
res = 1LL * res * x % mod;
return res;
}
inline void add(int &x, const int &y)
{
x += y;
if (x >= mod)
x -= mod;
}
inline void dec(int &x, const int &y)
{
x -= y;
if (x < 0)
x += mod;
}
typedef std::vector<int> vi;
typedef std::pair<vi, vi> pvi;
#define mp(x, y) std::make_pair(x, y)
const int MaxN = 2e5 + 5;
const int INF = 0x3f3f3f3f;
int fac[MaxN], fac_inv[MaxN], pwm[MaxN], ind[MaxN];
inline void fac_init(int n)
{
ind[1] = 1;
for (int i = 2; i <= n; ++i)
ind[i] = 1LL * ind[mod % i] * (mod - mod / i) % mod;
fac[0] = 1;
for (int i = 1; i <= n; ++i)
fac[i] = 1LL * fac[i - 1] * i % mod;
fac_inv[n] = qpow(fac[n], mod - 2);
for (int i = n - 1; i >= 0; --i)
fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
}
namespace polynomial
{
int P, L;
int rev[MaxN];
inline void DFT_init(int n)
{
P = 0, L = 1;
while (L < n)
L <<= 1, ++P;
for (int i = 1; i < L; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (P - 1));
}
inline void DFT(vi &a, int n, int opt)
{
for (int i = 0; i < n; ++i)
if (i < rev[i])
std::swap(a[i], a[rev[i]]);
int g = opt == 1 ? 3 : (mod + 1) / 3;
for (int k = 1; k < n; k <<= 1)
{
int omega = qpow(g, (mod - 1) / (k << 1));
for (int i = 0; i < n; i += k << 1)
{
int x = 1;
for (int j = 0; j < k; ++j)
{
int u = a[i + j];
int v = 1LL * a[i + j + k] * x % mod;
add(a[i + j] = u, v);
dec(a[i + j + k] = u, v);
x = 1LL * x * omega % mod;
}
}
}
if (opt == -1)
{
int inv = ind[n];
for (int i = 0; i < n; ++i)
a[i] = 1LL * a[i] * inv % mod;
}
}
inline vi plus(vi a, vi b)
{
int sze = std::max(a.size(), b.size());
a.resize(sze), b.resize(sze);
for (int i = 0; i < sze; ++i)
add(a[i], b[i]);
return a;
}
inline vi mul(vi a, vi b, int lim = INF)
{
int sze = a.size() + b.size() - 1;
DFT_init(sze), a.resize(L, 0), b.resize(L, 0);
vi c(L);
DFT(a, L, 1), DFT(b, L, 1);
for (int i = 0; i < L; ++i)
c[i] = 1LL * a[i] * b[i] % mod;
DFT(c, L, -1);
return c.resize(std::min(sze, lim)), c;
}
inline vi inverse(vi a)
{
int n = a.size(), m = 1;
vi b(1, qpow(a[0], mod - 2)), ta;
while (m < n)
{
m <<= 1;
DFT_init(m << 1);
b.resize(L, 0);
(ta = a).resize(m);
ta.resize(L, 0);
DFT(b, L, 1), DFT(ta, L, 1);
for (int i = 0; i < L; ++i)
b[i] = 1LL * b[i] * (mod + 2 - 1LL * ta[i] * b[i] % mod) % mod;
DFT(b, L, -1);
b.resize(m, 0);
}
return b.resize(n), b;
}
inline vi derivative(vi a)
{
vi res(0);
for (int i = 1, lim = a.size(); i < lim; ++i)
res.push_back(1LL * i * a[i] % mod);
return res;
}
inline vi anti_derivative(vi a)
{
vi res(1, 0);
for (int i = 0, lim = a.size(); i < lim; ++i)
res.push_back(1LL * a[i] * ind[i + 1] % mod);
return res;
}
inline vi ln(vi a)
{
return anti_derivative(mul(derivative(a), inverse(a), a.size() - 1));
}
inline vi exp(vi a)
{
int n = a.size(), m = 1;
vi b(1, 1), ta;
while (m < n)
{
m <<= 1;
b.resize(m, 0);
vi ln_b = ln(b);
(ta = a).resize(m);
add(ta[0], 1);
for (int i = 0; i < m; ++i)
dec(ta[i], ln_b[i]);
b = mul(b, ta, m);
}
return b.resize(n), b;
}
}
vi sum;
int n, m;
int a[MaxN];
inline pvi solve(int l, int r)
{
using namespace polynomial;
if (l == r)
{
vi t(1, 1); t.push_back(mod - a[l]);
return mp(vi(1, 1), t);
}
int mid = (l + r) >> 1;
pvi lef = solve(l, mid), rit = solve(mid + 1, r);
return mp(plus(mul(lef.first, rit.second), mul(rit.first, lef.second)), mul(lef.second, rit.second));
}
inline vi get_sum(vi a)
{
vi res(0); int n = a.size();
for (int i = 0; i < n; ++i)
res.push_back(1LL * a[i] * sum[i] % mod);
return res;
}
int main()
{
read(n), read(m), fac_init(MaxN - 1);
for (int i = 0; i <= (n << 1); ++i)
pwm[i] = qpow(i, m);
int prod = 1;
for (int i = 1; i <= n; ++i)
{
read(a[i]);
prod = 1LL * prod * a[i] % mod;
}
if (n == 1)
return puts(m ? "0" : "1"), 0;
using namespace polynomial;
pvi t = solve(1, n);
sum = mul(t.first, inverse(t.second), n - 1);
vi A(0), B(0);
for (int i = 0; i < n - 1; ++i)
{
A.push_back(1LL * pwm[i + 1] * fac_inv[i] % mod);
B.push_back(1LL * pwm[i + 1] * pwm[i + 1] % mod * fac_inv[i] % mod);
}
B = get_sum(mul(B, inverse(A), n - 1));
A = exp(get_sum(ln(A)));
int res = mul(A, B)[n - 2];
std::cout << 1LL * fac[n - 2] * prod % mod * res % mod << '\n';
return 0;
}