题意
给一颗n点无根树,令 D ( x , y ) D(x,y) D(x,y)表示编号在 [ x , y ] [x,y] [x,y]的点组成的树的直径,求 ∑ l = 1 n ∑ r = l + 1 n D ( l , r ) \sum_{l=1}^n\sum_{r=l+1}^nD(l,r) ∑l=1n∑r=l+1nD(l,r)
解题思路
考虑分治,可以把当前[l,r]的问题分解成[l,mid]的问题和[mid+1,r]的子问题。
现在考虑如何快速所有左端点在
[
l
,
m
i
d
]
[l,mid]
[l,mid],右端点在
[
m
i
d
+
1
,
r
]
[mid+1,r]
[mid+1,r]的直径和。
对于
i
∈
[
l
,
m
i
d
]
i\in[l,mid]
i∈[l,mid],
[
m
i
d
+
1
,
r
]
[mid+1,r]
[mid+1,r]的点j有三种类型:
- D ( i , j ) = D ( i , m i d ) D(i,j)=D(i,mid) D(i,j)=D(i,mid)
- D ( i , j ) = ( D ( i , m i d ) + D ( m i d + 1 , j ) ) / 2 + d i s ( c 1 , c 2 ) D(i,j)=(D(i,mid)+D(mid+1,j))/2+dis(c1,c2) D(i,j)=(D(i,mid)+D(mid+1,j))/2+dis(c1,c2), (其中c1是(i,mid)这部分点集直径的中心,c2是(mid+1,j)这部分点集直径的中心
- D ( i , j ) = D ( m i d + 1 , j ) D(i,j)=D(mid+1,j) D(i,j)=D(mid+1,j)
设分为三段
[
m
i
d
1
+
,
p
1
)
,
[
p
1
,
p
2
]
,
(
p
2
,
r
]
[mid1+,p1), [p1,p2], (p2, r]
[mid1+,p1),[p1,p2],(p2,r],随着i变小,p1和p2都在增大,所以可以双指针扫出p1和p2
所以在每一层,枚举左端点,快速维护右边的部分。右边部分,1和3类型对答案的贡献容易求,2类型也就是
[
p
1
,
p
2
]
[p1,p2]
[p1,p2]这部分,可以把它拆成
D
(
i
,
m
i
d
)
,
D
(
m
i
d
+
1
,
j
)
和
d
i
s
(
c
1
,
c
2
)
D(i,mid), D(mid+1,j)和dis(c1,c2)
D(i,mid),D(mid+1,j)和dis(c1,c2)来求,
D
(
i
,
m
i
d
)
,
D
(
m
i
d
+
1
,
j
)
D(i,mid), D(mid+1,j)
D(i,mid),D(mid+1,j)容易得到,而
d
i
s
(
c
1
,
c
2
)
dis(c1,c2)
dis(c1,c2)的总和可以看成是树上加了若干个c2,求一个点c1到它们的距离和,这个可以用数据结构维护,这里用树链剖分+线段树维护,时间复杂度理论上是
n
l
o
g
3
n
nlog^3n
nlog3n,其中分治1个log,树链剖分+线段树2个log,如果用动态点分治维护可以降低到
n
l
o
g
2
n
nlog^2n
nlog2n(还需要改成O(1)查询LCA)
#include<bits/stdc++.h>
#define ll long long
#define pb push_back
#define lowbit(x) ((x)&(-(x)))
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
#define fors(i, a, b) for(int i = (a); i < (b); ++i)
using namespace std;
const int maxn = 2e5 + 5;
int sz[maxn], top[maxn], son[maxn], dep[maxn], fa[maxn];
int dfn[maxn], id[maxn], idx = 0;
vector<int> g[maxn];
void dfs1(int u){//sz, fa, dep, son
sz[u] = 1;
for(int v:g[u]){
if(v == fa[u]) continue;
dep[v] = dep[u]+1; fa[v] = u;
dfs1(v); sz[u] += sz[v];
if(sz[son[u]] < sz[v]) son[u] = v;
}return;
}
void dfs2(int u, int tp){//dfn,id,top
dfn[id[u] = ++idx] = u;
top[u] = tp; if(son[u]) dfs2(son[u], tp);
for(int v: g[u]){
if(v == fa[u]||v==son[u]) continue;
dfs2(v, v);
}
}
inline int lca(int u, int v){
while (top[u] != top[v]) {
if (dep[top[u]] > dep[top[v]]) u = fa[top[u]];
else v = fa[top[v]];
}
return dep[u] < dep[v] ? u : v;
}
int dist(int u, int v){return dep[u]+dep[v]-2*dep[lca(u,v)]; }
struct Dia{
int u, v, d;
Dia(int _u=0, int _v=0, int _d=-1) : u(_u), v(_v), d(_d) {}
bool operator == (const Dia & a) const {
return (u == a.u && v == a.v||u==a.v && v == a.u);
}
bool operator != (const Dia & a) const {return !(*this == a);}
Dia operator + (const Dia & a) const {
if (a.d == -1) return *this;
if (d == -1) return a;
Dia c = (d < a.d ? a : *this);
for (auto x : {u, v}) {
for (auto y : {a.u, a.v}) {
int d = dist(x, y);
if (d > c.d) c = Dia(x, y, d);
}
}
return c;
}
};
int get_center(Dia x){
int u = x.u,v = x.v; if(u == v) return u;
if(dep[u] == dep[v]) return lca(u,v);
if(dep[u] < dep[v]) swap(u, v);
int cur = u;
while( 2*(dep[u]-dep[fa[top[cur]]] ) <= x.d ) cur = fa[top[cur]];
int l = id[top[cur]], r = id[cur];
int res = -1;
while(l <= r){
int p = dfn[mid];
if((dep[u]-dep[p])*2 <= x.d) res = p, r = mid-1;
else l = mid+1;
}
assert(res != -1);
return res;
}
int n;
void add(int x, int y){g[x].pb(y); g[y].pb(x);}
namespace tree{
int lz[maxn<<2];ll sum[maxn<<2];
void down(int rt, int l, int r){
if(lz[rt]){
lz[rt<<1] += lz[rt];
lz[rt<<1|1] += lz[rt];
sum[rt<<1] += lz[rt] * (ll)(mid-l+1);
sum[rt<<1|1] += lz[rt]*(ll)(r-mid);
lz[rt] = 0;
}
}
void add(int rt, int l, int r, int L, int R, int x){
if(L <= l && r <= R) {
lz[rt] += x; sum[rt] += x * (r-l+1); return;
}down(rt, l, r);
if(L <= mid) add(lson , L, R, x);
if(R > mid) add(rson, L, R, x);
sum[rt] = sum[rt<<1] + sum[rt<<1|1];
return;
}
ll qry(int rt, int l, int r, int L, int R){
if(L <= l && r <= R) return sum[rt];
down(rt,l,r); ll res = 0;
if(L <= mid) res += qry(lson, L, R);
if(R > mid) res += qry(rson, L, R); return res;
}
}
Dia e[maxn];
ll sumd[maxn];
ll sum_dep = 0, sum_sz = 0;
void del(int x){
sum_sz--; sum_dep -= dep[x];
while(x) tree::add(1,1,idx,id[top[x]], id[x], -2), x = fa[top[x]];
}
void add(int x){
sum_sz++; sum_dep += dep[x];
while(x) tree::add(1,1,idx,id[top[x]], id[x], 2), x = fa[top[x]];
}
ll qry(int x){
ll res = sum_sz*dep[x] + sum_dep;
while(x) res -= (tree::qry(1,1,idx,id[top[x]],id[x])), x = fa[top[x]];
return res;
}
int cen[maxn];
ll sol(int l, int r){
if(l == r) return 0;
ll ans = 0;
ans += sol(l, mid); ans += sol(mid+1, r);
assert(sum_dep == 0 && sum_sz == 0);
e[mid] = Dia(mid,mid,0);
for(int i = mid-1; i >= l; --i) e[i] = Dia(i,i,0)+e[i+1];
e[mid+1] = Dia(mid+1, mid+1, 0);
sumd[mid] = sumd[mid+1] = 0;
for(int i = mid+2; i <= r; ++i) e[i] = e[i-1] + Dia(i,i,0), sumd[i] = sumd[i-1]+e[i].d;
int p1 = mid+1, p2 = mid;
for(int i = mid; i >= l; --i){
int c1 = get_center(e[i]);
while(p2+1 <= r){
int c2 = get_center(e[p2+1]);
cen[p2+1] = c2;
if( (e[i]+e[p2+1]).d*2 == e[i].d + e[p2+1].d + 2*dist(c1, c2) ) add(c2), ++p2;
else break;
}
while(p1 <= r && (e[i]+e[p1]).d == e[i].d) {
int c = get_center(e[p1]);
del(c); p1++;
}
ans += (ll)(p1-mid-1) * e[i].d;
ans += sumd[r]-sumd[p2];
ans += (p2-p1+1) * (ll)(e[i].d/2) + (sumd[p2]-sumd[p1-1])/2;
ans += qry(c1);
}
while(p1 <= p2) del(cen[p1]), p1++;
return ans;
}
int main()
{
scanf("%d", &n); fors(i,1,n) {
int u,v;
scanf("%d%d",&u,&v);
add(n+i,u); add(n+i, v);
}
dep[1] = 1;
dfs1(1); dfs2(1,1);
cout<<sol(1,n)/2<<endl;
return 0;
}