BZOJ 3451 点分治 + NTT

题意

传送门 BZOJ 3451 Tyvj1953 Normal

题解

v v v 取为分治中心时, u u u v v v 连通,则 u u u 贡献为 1 1 1 u , v u,v u,v v v v 取为分治中心时连通的条件为 u , v u,v u,v 路径上任意节点都未被选取为分治中心,不属于路径上的节点对于这个概率没有影响,则概率等价于 u , v u,v u,v 路径上的点中, v v v 被首先选取的概率,其等于 1 / ( d i s t ( u , v ) + 1 ) 1/\Big(dist(u,v) + 1\Big) 1/(dist(u,v)+1)

总贡献为 ∑ u , v 1 / ( d i s t ( u , v ) + 1 ) \sum_{u,v} 1/\Big(dist(u,v) + 1\Big) u,v1/(dist(u,v)+1),统计距离相等的点对数量,求解答案即可。

统计点对方法:进行标准的点分治,对树上各节点与重心的距离做卷积,再减去子树的贡献。

总时间复杂度 O ( n log ⁡ 2 n ) O(n\log^2n) O(nlog2n)

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
constexpr ll MOD = 998244353, PRT = 3;
//constexpr ll MOD = 1004535809, PRT = 3;
ll qpow(ll x, ll n)
{
    ll res = 1;
    while (n > 0)
    {
        if (n & 1)
            res = res * x % MOD;
        x = x * x % MOD, n >>= 1;
    }
    return res;
}
vector<int> rev;
struct Poly : vector<ll>
{
    Poly() {}
    Poly(int n) : vector<ll>(n) {}
    Poly(const initializer_list<ll> &list) : vector<ll>(list) {}
    void fft(int n, bool inverse)
    {
        if ((int)rev.size() != n)
        {
            rev.resize(n);
            for (int i = 0; i < n; ++i)
                rev[i] = rev[i >> 1] >> 1 | (i & 1 ? n >> 1 : 0);
        }
        resize(n);
        for (int i = 0; i < n; ++i)
            if (i < rev[i])
                std::swap(at(i), at(rev[i]));

        for (int m = 1; m < n; m <<= 1)
        {
            int m2 = m << 1;
            ll _w = qpow(inverse ? qpow(PRT, MOD - 2) : PRT, (MOD - 1) / m2);
            for (int i = 0; i < n; i += m2)
                for (int w = 1, j = 0; j < m; ++j, w = w * _w % MOD)
                {
                    ll &x = at(i + j), &y = at(i + j + m), t = w * y % MOD;
                    y = x - t;
                    if (y < 0)
                        y += MOD;
                    x += t;
                    if (x >= MOD)
                        x -= MOD;
                }
        }
    }
    void dft(int n) { fft(n, 0); };
    void idft(int n)
    {
        fft(n, 1);
        for (int i = 0, inv = qpow(n, MOD - 2); i < n; ++i)
            at(i) = at(i) * inv % MOD;
    }
    Poly operator*(const Poly &p) const
    {
        auto a = *this, b = p;
        int k = 1, n = a.size() + b.size() - 1;
        while (k < n)
            k <<= 1;
        a.dft(k), b.dft(k);
        for (int i = 0; i < k; ++i)
            a[i] = a[i] * b[i] % MOD;
        a.idft(k);
        a.resize(n);
        return a;
    }
};
constexpr int MAXN = 3E4 + 5;
int N;
vector<int> G[MAXN];
int sz[MAXN], num[MAXN];
bool del[MAXN];

void get_rt(int v, int p, int n, int &rt, int &mx)
{
    sz[v] = 1;
    int t = 0;
    for (int u : G[v])
        if (!del[u] && u != p)
            get_rt(u, v, n, rt, mx), sz[v] += sz[u], t = max(t, sz[u]);
    t = max(t, n - sz[v]);
    if (mx == -1 || t < mx)
        mx = t, rt = v;
}

void get_sz(int v, int p)
{
    sz[v] = 1;
    for (int u : G[v])
        if (!del[u] && u != p)
            get_sz(u, v), sz[v] += sz[u];
}

void dfs(int v, int p, int d, vector<int> &ds)
{
    ds.push_back(d);
    for (int u : G[v])
        if (!del[u] && u != p)
            dfs(u, v, d + 1, ds);
}

void count(vector<int> &ds, int x)
{
    int dmax = 0;
    for (int d : ds)
        dmax = max(dmax, d);
    Poly f(dmax + 1);
    vector<int> c(dmax + 1);
    for (int d : ds)
        ++c[d];
    for (int i = 0; i <= dmax; ++i)
        f[i] = c[i];
    f = f * f;
    for (int i = 0; i < (int)f.size(); ++i)
        num[i] += x * f[i];
}

void solve(int v, int n)
{
    int rt = -1, mx = -1;
    get_rt(v, -1, n, rt, mx);
    get_sz(rt, -1);
    vector<int> ds;
    ds.push_back(0);
    del[rt] = 1;
    for (int u : G[rt])
        if (!del[u])
        {
            vector<int> tds;
            dfs(u, -1, 1, tds);
            count(tds, -1);
            ds.insert(ds.end(), tds.begin(), tds.end());
        }
    count(ds, 1);
    for (int u : G[rt])
        if (!del[u])
            solve(u, sz[u]);
}

int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> N;
    for (int i = 1; i < N; ++i)
    {
        int u, v;
        cin >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    solve(0, N);
    long double res = 0;
    for (int d = 0; d < N; ++d)
        res += (long double)num[d] / (1 + d);
    cout << fixed << setprecision(4) << res << '\n';
    return 0;
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值