1901: #2320. 「清华集训 2017」生成树计数

题目描述

在一个 sss 个点的图中,存在 s−ns-ns−n 条边,使图中形成了 nnn 个连通块,第 iii 个连通块中有 aia_iai 个点。

现在我们需要再连接 n−1n-1n−1 条边,使该图变成一棵树。对一种连边方案,设原图中第 iii 个连通块连出了 did_idi 条边,那么这棵树 TTT 的价值为:

你的任务是求出所有可能的生成树的价值之和,对 998244353998244353998244353 取模。

输入

输入的第一行包含两个整数 n,mn,mn,m,意义见题目描述。

接下来一行有 nnn 个整数,第 iii 个整数表示 aia_iai (1≤ai<998244353)(1\le a_i< 998244353)(1≤ai<998244353)。

  • 你可以由 aia_iai 计算出图的总点数 sss,所以在输入中不再给出 sss 的值。

输出

输出包含一行一个整数,表示答案。

样例输入 

3 1
2 3 4

样例输出 

1728

提示

本题共有 202020 个测试点,每个测试点 555 分。

  • 20%20\%20% 的数据中,n≤500n\le500n≤500。

  • 另外 20%20\%20% 的数据中,n≤3000n \le 3000n≤3000。

  • 另外 10%10\%10% 的数据中,n≤10010,m=1n \le 10010, m = 1n≤10010,m=1。

  • 另外 10%10\%10%的数据中,n≤10015,m=2n \le 10015,m = 2n≤10015,m=2。

  • 另外 20%20\%20% 的数据中,所有 aia_iai 相等。

  • 100%100\%100% 的数据中,n≤3×104,m≤30n \le 3\times 10^4,m \le 30n≤3×104,m≤30。

其中,每一个部分分的测试点均有一定梯度。

#include <bits/stdc++.h>
 
template <class T>
inline void read(T &x)
{
    static char ch; 
    while (!isdigit(ch = getchar()));
    x = ch - '0'; 
    while (isdigit(ch = getchar()))
        x = x * 10 + ch - '0'; 
}
 
const int mod = 998244353; 
 
inline int qpow(int x, int y)
{
    int res = 1; 
    for (; y; y >>= 1, x = 1LL * x * x % mod)
        if (y & 1)
            res = 1LL * res * x % mod; 
    return res; 
}
 
inline void add(int &x, const int &y)
{
    x += y; 
    if (x >= mod)
        x -= mod; 
}
 
inline void dec(int &x, const int &y)
{
    x -= y;
    if (x < 0)
        x += mod; 
}
 
typedef std::vector<int> vi; 
typedef std::pair<vi, vi> pvi; 
#define mp(x, y) std::make_pair(x, y)
 
const int MaxN = 2e5 + 5; 
const int INF = 0x3f3f3f3f; 
 
int fac[MaxN], fac_inv[MaxN], pwm[MaxN], ind[MaxN]; 
 
inline void fac_init(int n)
{
    ind[1] = 1; 
    for (int i = 2; i <= n; ++i)
        ind[i] = 1LL * ind[mod % i] * (mod - mod / i) % mod; 
 
    fac[0] = 1; 
    for (int i = 1; i <= n; ++i)
        fac[i] = 1LL * fac[i - 1] * i % mod; 
 
    fac_inv[n] = qpow(fac[n], mod - 2); 
    for (int i = n - 1; i >= 0; --i)
        fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod; 
}
 
namespace polynomial
{
    int P, L; 
    int rev[MaxN]; 
 
    inline void DFT_init(int n)
    {
        P = 0, L = 1; 
        while (L < n)
            L <<= 1, ++P; 
        for (int i = 1; i < L; ++i)
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (P - 1)); 
    }
     
    inline void DFT(vi &a, int n, int opt)
    {
        for (int i = 0; i < n; ++i)
            if (i < rev[i])
                std::swap(a[i], a[rev[i]]);
 
        int g = opt == 1 ? 3 : (mod + 1) / 3; 
        for (int k = 1; k < n; k <<= 1)
        {
            int omega = qpow(g, (mod - 1) / (k << 1)); 
            for (int i = 0; i < n; i += k << 1)
            {
                int x = 1; 
                for (int j = 0; j < k; ++j)
                {
                    int u = a[i + j]; 
                    int v = 1LL * a[i + j + k] * x % mod; 
                    add(a[i + j] = u, v); 
                    dec(a[i + j + k] = u, v); 
                    x = 1LL * x * omega % mod; 
                }
            }
        }
        if (opt == -1)
        {
            int inv = ind[n]; 
            for (int i = 0; i < n; ++i)
                a[i] = 1LL * a[i] * inv % mod; 
        }
    }
 
    inline vi plus(vi a, vi b)
    {
        int sze = std::max(a.size(), b.size()); 
        a.resize(sze), b.resize(sze); 
 
        for (int i = 0; i < sze; ++i)
            add(a[i], b[i]); 
        return a; 
    }
    inline vi mul(vi a, vi b, int lim = INF)
    {
        int sze = a.size() + b.size() - 1; 
        DFT_init(sze), a.resize(L, 0), b.resize(L, 0); 
 
        vi c(L); 
        DFT(a, L, 1), DFT(b, L, 1); 
        for (int i = 0; i < L; ++i)
            c[i] = 1LL * a[i] * b[i] % mod; 
        DFT(c, L, -1);
 
        return c.resize(std::min(sze, lim)), c; 
    }
    inline vi inverse(vi a)
    {
        int n = a.size(), m = 1; 
        vi b(1, qpow(a[0], mod - 2)), ta; 
        while (m < n)
        {
            m <<= 1; 
            DFT_init(m << 1); 
 
            b.resize(L, 0); 
            (ta = a).resize(m); 
            ta.resize(L, 0); 
 
            DFT(b, L, 1), DFT(ta, L, 1); 
            for (int i = 0; i < L; ++i)
                b[i] = 1LL * b[i] * (mod + 2 - 1LL * ta[i] * b[i] % mod) % mod; 
            DFT(b, L, -1); 
 
            b.resize(m, 0); 
        }
        return b.resize(n), b; 
    }
    inline vi derivative(vi a)
    {
        vi res(0); 
        for (int i = 1, lim = a.size(); i < lim; ++i)
            res.push_back(1LL * i * a[i] % mod); 
        return res; 
    }
    inline vi anti_derivative(vi a)
    {
        vi res(1, 0); 
        for (int i = 0, lim = a.size(); i < lim; ++i)
            res.push_back(1LL * a[i] * ind[i + 1] % mod); 
        return res; 
    }
    inline vi ln(vi a)
    {
        return anti_derivative(mul(derivative(a), inverse(a), a.size() - 1)); 
    }
    inline vi exp(vi a)
    {
        int n = a.size(), m = 1; 
        vi b(1, 1), ta; 
        while (m < n)
        {
            m <<= 1; 
 
            b.resize(m, 0); 
            vi ln_b = ln(b); 
 
            (ta = a).resize(m); 
            add(ta[0], 1); 
            for (int i = 0; i < m; ++i)
                dec(ta[i], ln_b[i]); 
            b = mul(b, ta, m); 
        }
        return b.resize(n), b; 
    }
}
 
vi sum; 
int n, m; 
int a[MaxN]; 
 
inline pvi solve(int l, int r)
{
    using namespace polynomial; 
    if (l == r)
    {
        vi t(1, 1); t.push_back(mod - a[l]); 
        return mp(vi(1, 1), t); 
    }
    int mid = (l + r) >> 1; 
    pvi lef = solve(l, mid), rit = solve(mid + 1, r); 
    return mp(plus(mul(lef.first, rit.second), mul(rit.first, lef.second)), mul(lef.second, rit.second)); 
}
 
inline vi get_sum(vi a)
{
    vi res(0); int n = a.size(); 
    for (int i = 0; i < n; ++i)
        res.push_back(1LL * a[i] * sum[i] % mod); 
    return res; 
}
 
int main()
{
    read(n), read(m), fac_init(MaxN - 1); 
    for (int i = 0; i <= (n << 1); ++i)
        pwm[i] = qpow(i, m); 
 
    int prod = 1; 
    for (int i = 1; i <= n; ++i)
    {
        read(a[i]);
        prod = 1LL * prod * a[i] % mod; 
    }
 
    if (n == 1)
        return puts(m ? "0" : "1"), 0; 
 
    using namespace polynomial; 
 
    pvi t = solve(1, n); 
    sum = mul(t.first, inverse(t.second), n - 1); 
 
    vi A(0), B(0); 
    for (int i = 0; i < n - 1; ++i)
    {
        A.push_back(1LL * pwm[i + 1] * fac_inv[i] % mod); 
        B.push_back(1LL * pwm[i + 1] * pwm[i + 1] % mod * fac_inv[i] % mod); 
    }
    B = get_sum(mul(B, inverse(A), n - 1)); 
    A = exp(get_sum(ln(A))); 
 
    int res = mul(A, B)[n - 2]; 
    std::cout << 1LL * fac[n - 2] * prod % mod * res % mod << '\n'; 
 
    return 0; 
}

  • 10
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值