1458F - Range Diameter Sum(分治+数据结构)

题目链接

题意

给一颗n点无根树,令 D ( x , y ) D(x,y) D(x,y)表示编号在 [ x , y ] [x,y] [x,y]的点组成的树的直径,求 ∑ l = 1 n ∑ r = l + 1 n D ( l , r ) \sum_{l=1}^n\sum_{r=l+1}^nD(l,r) l=1nr=l+1nD(l,r)

解题思路

考虑分治,可以把当前[l,r]的问题分解成[l,mid]的问题和[mid+1,r]的子问题。
现在考虑如何快速所有左端点在 [ l , m i d ] [l,mid] [l,mid],右端点在 [ m i d + 1 , r ] [mid+1,r] [mid+1,r]的直径和。
对于 i ∈ [ l , m i d ] i\in[l,mid] i[l,mid], [ m i d + 1 , r ] [mid+1,r] [mid+1,r]的点j有三种类型:

  1. D ( i , j ) = D ( i , m i d ) D(i,j)=D(i,mid) D(i,j)=D(i,mid)
  2. D ( i , j ) = ( D ( i , m i d ) + D ( m i d + 1 , j ) ) / 2 + d i s ( c 1 , c 2 ) D(i,j)=(D(i,mid)+D(mid+1,j))/2+dis(c1,c2) D(i,j)=(D(i,mid)+D(mid+1,j))/2+dis(c1,c2), (其中c1是(i,mid)这部分点集直径的中心,c2是(mid+1,j)这部分点集直径的中心
  3. D ( i , j ) = D ( m i d + 1 , j ) D(i,j)=D(mid+1,j) D(i,j)=D(mid+1,j)

设分为三段 [ m i d 1 + , p 1 ) , [ p 1 , p 2 ] , ( p 2 , r ] [mid1+,p1), [p1,p2], (p2, r] [mid1+,p1),[p1,p2],(p2,r],随着i变小,p1和p2都在增大,所以可以双指针扫出p1和p2
所以在每一层,枚举左端点,快速维护右边的部分。右边部分,1和3类型对答案的贡献容易求,2类型也就是 [ p 1 , p 2 ] [p1,p2] [p1,p2]这部分,可以把它拆成 D ( i , m i d ) , D ( m i d + 1 , j ) 和 d i s ( c 1 , c 2 ) D(i,mid), D(mid+1,j)和dis(c1,c2) D(i,mid),D(mid+1,j)dis(c1,c2)来求, D ( i , m i d ) , D ( m i d + 1 , j ) D(i,mid), D(mid+1,j) D(i,mid),D(mid+1,j)容易得到,而 d i s ( c 1 , c 2 ) dis(c1,c2) dis(c1,c2)的总和可以看成是树上加了若干个c2,求一个点c1到它们的距离和,这个可以用数据结构维护,这里用树链剖分+线段树维护,时间复杂度理论上是 n l o g 3 n nlog^3n nlog3n,其中分治1个log,树链剖分+线段树2个log,如果用动态点分治维护可以降低到 n l o g 2 n nlog^2n nlog2n(还需要改成O(1)查询LCA)

#include<bits/stdc++.h>
#define ll long long
#define pb push_back
#define lowbit(x) ((x)&(-(x)))
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
#define fors(i, a, b) for(int i = (a); i < (b); ++i)
using namespace std;
const int maxn = 2e5 + 5;
int sz[maxn],  top[maxn], son[maxn], dep[maxn], fa[maxn];
int dfn[maxn], id[maxn], idx = 0;
vector<int> g[maxn];
void dfs1(int u){//sz, fa, dep, son
    sz[u] = 1;
    for(int v:g[u]){
        if(v == fa[u]) continue;
        dep[v] = dep[u]+1; fa[v] = u;
        dfs1(v); sz[u] += sz[v];
        if(sz[son[u]] < sz[v]) son[u] = v;
    }return;
}
void dfs2(int u, int tp){//dfn,id,top
    dfn[id[u] = ++idx] = u;
    top[u] = tp; if(son[u]) dfs2(son[u], tp);
    for(int v: g[u]){
        if(v == fa[u]||v==son[u]) continue;
        dfs2(v, v);
    }
}
inline int lca(int u, int v){
    while (top[u] != top[v]) {
        if (dep[top[u]] > dep[top[v]]) u = fa[top[u]];
        else v = fa[top[v]];
    }
    return dep[u] < dep[v] ? u : v;
}
int dist(int u, int v){return dep[u]+dep[v]-2*dep[lca(u,v)]; }
struct Dia{
    int u, v, d;
    Dia(int _u=0, int _v=0, int _d=-1) : u(_u), v(_v), d(_d) {}
    bool operator == (const Dia & a) const {
        return (u == a.u && v == a.v||u==a.v && v == a.u);
    }
    bool operator != (const Dia & a) const {return !(*this == a);}
    Dia operator + (const Dia & a) const {
        if (a.d == -1) return *this;
        if (d == -1) return a;
        Dia c = (d < a.d ? a : *this);
        for (auto x : {u, v}) {
            for (auto y : {a.u, a.v}) {
                int d = dist(x, y);
                if (d > c.d) c = Dia(x, y, d);
            }
        }
        return c;
    }
};
int get_center(Dia x){
    int u = x.u,v = x.v; if(u == v) return u;
    if(dep[u] == dep[v]) return lca(u,v);
    if(dep[u] < dep[v]) swap(u, v);
    int cur = u;
    while( 2*(dep[u]-dep[fa[top[cur]]] ) <= x.d ) cur = fa[top[cur]];
    int l = id[top[cur]], r = id[cur];
    int res = -1;
    while(l <= r){
        int p = dfn[mid];
        if((dep[u]-dep[p])*2 <= x.d) res = p, r = mid-1;
        else l = mid+1;
    }
    assert(res != -1);
    return res;
}
int n;
void add(int x, int y){g[x].pb(y); g[y].pb(x);}
namespace tree{
    int lz[maxn<<2];ll sum[maxn<<2];
    void down(int rt, int l, int r){
        if(lz[rt]){
            lz[rt<<1] += lz[rt];
            lz[rt<<1|1] += lz[rt];
            sum[rt<<1] += lz[rt] * (ll)(mid-l+1);
            sum[rt<<1|1] += lz[rt]*(ll)(r-mid);
            lz[rt] = 0;
        }
    }
    void add(int rt, int l, int r, int L, int R, int x){
        if(L <= l && r <= R) {
            lz[rt] += x; sum[rt] += x * (r-l+1); return;
        }down(rt, l, r);
        if(L <= mid) add(lson , L, R, x);
        if(R > mid) add(rson, L, R, x);
        sum[rt] = sum[rt<<1] + sum[rt<<1|1];
        return;
    }
    ll qry(int rt, int l, int r, int L, int R){
        if(L <= l && r <= R) return sum[rt];
        down(rt,l,r); ll res = 0;
        if(L <= mid) res += qry(lson, L, R);
        if(R > mid) res += qry(rson, L, R); return res;
    }
}
Dia e[maxn];
ll sumd[maxn];
ll sum_dep = 0, sum_sz = 0;
void del(int x){
    sum_sz--; sum_dep -= dep[x];
    while(x) tree::add(1,1,idx,id[top[x]], id[x], -2), x = fa[top[x]];
}
void add(int x){
    sum_sz++; sum_dep += dep[x];
    while(x) tree::add(1,1,idx,id[top[x]], id[x], 2), x = fa[top[x]];
}
ll qry(int x){
    ll res = sum_sz*dep[x] + sum_dep;
    while(x) res -= (tree::qry(1,1,idx,id[top[x]],id[x])), x = fa[top[x]];
    return res;
}
int cen[maxn];
ll sol(int l, int r){
    if(l == r) return 0;
    ll ans = 0;
    ans += sol(l, mid); ans += sol(mid+1, r);
    assert(sum_dep == 0 && sum_sz == 0);
    e[mid] = Dia(mid,mid,0);
    for(int i = mid-1; i >= l; --i) e[i] = Dia(i,i,0)+e[i+1];
    e[mid+1] = Dia(mid+1, mid+1, 0);
    sumd[mid] = sumd[mid+1] = 0;
    for(int i = mid+2; i <= r; ++i) e[i] = e[i-1] + Dia(i,i,0), sumd[i] = sumd[i-1]+e[i].d;
    int p1 = mid+1, p2 = mid;

    for(int i = mid; i >= l; --i){
        int c1 = get_center(e[i]);
        while(p2+1 <= r){
            int c2 = get_center(e[p2+1]);
            cen[p2+1] = c2;
            if( (e[i]+e[p2+1]).d*2 == e[i].d + e[p2+1].d + 2*dist(c1, c2) ) add(c2), ++p2;
            else break;
        }
        while(p1 <= r && (e[i]+e[p1]).d == e[i].d) {
            int c = get_center(e[p1]);
            del(c); p1++;
        }
        ans += (ll)(p1-mid-1) * e[i].d;
        ans += sumd[r]-sumd[p2];
        ans += (p2-p1+1) * (ll)(e[i].d/2) + (sumd[p2]-sumd[p1-1])/2;
        ans += qry(c1);
    }
    while(p1 <= p2) del(cen[p1]), p1++;
    return ans;
}
int main()
{
    scanf("%d", &n); fors(i,1,n) {
        int u,v;
        scanf("%d%d",&u,&v);
        add(n+i,u); add(n+i, v);
    }
    dep[1] = 1;
    dfs1(1); dfs2(1,1);
    cout<<sol(1,n)/2<<endl;
	return 0;
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值