看这样一个题(dsu on the tree):
给你一棵树,每个节点有一种颜色,问你每个子树x的颜色数最多的那种颜色,如果颜色数相同,那么种类数相加。
考虑最暴力的暴力,对于每个点遍历它的子树,统计答案,然后再撤销。但是这样太傻了,每个点显然可以继承一个儿子的信息,我们选择继承它的重儿子的信息,只 dfs 轻儿子。这样对于每个点,会被 dfs 它到根之间轻边数量次。所以复杂度是
O
(
n
log
n
)
O(n\log n)
O(nlogn)。
如果需要维护的是关于深度的信息呢?我们引入长链剖分。长链剖分,类似于重链剖分,我们定义每个点的 len 为从这个点出发向下的最长链的长度,把每个点的“长儿子”定义为所有儿子里 len 最长的点。
考虑维护深度信息,发现这时候我们继承重儿子显得很浪费,我们选择继承长儿子,然后合并短儿子的深度。发现每条长链只会在链顶被遍历一遍,而长链互不相交,因此复杂度是优秀的 O ( n ) O(n) O(n)。
长链剖分还有一个应用是 O ( 1 ) O(1) O(1) 求 k k k 级祖先,在这里就不啰嗦了。
例题:[WC2010]重建计划
二分答案以后,就是找边数在
[
L
,
U
]
[L,U]
[L,U] 的最长链。考虑暴力的 dp,设
f
[
i
]
[
j
]
f[i][j]
f[i][j] 表示
i
i
i 的子树中深度为
j
j
j 的点与
i
i
i 的最长距离。这是以深度为下标的信息,我们尝试用长链剖分去优化。一条链的 dfs 序是连续的一段,用
f
[
d
f
n
i
+
j
]
f[dfn_i+j]
f[dfni+j] 表示
f
[
i
]
[
j
]
f[i][j]
f[i][j],我们发现继承长儿子信息的时候深度恰好“后移”了一位。我们用线段树维护这个数组,然后合并短儿子的时候顺便统计答案。
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
using namespace std;
struct edge{
int to,next,w;
}ed[2000010];
int sz,head[1000010],pos[1000010],len[1000010],son[1000010],tim,w[1000010],n,L,U;
double Max[4000010],x,ad[4000010];
double z,f[1000010],g[1000010];
void add_edge(int from,int to,int w)
{
ed[++sz].to=to;
ed[sz].next=head[from];
ed[sz].w=w;
head[from]=sz;
}
void push_down(int root,int nl,int nr)
{
if(ad[root])
{
Max[root<<1]+=ad[root];
Max[root<<1|1]+=ad[root];
ad[root<<1]+=ad[root];
ad[root<<1|1]+=ad[root];
ad[root]=0;
}
}
void update(int root,int l,int r,int x,double k)
{
if(l==r)
{
Max[root]=max(Max[root],k);
return;
}
int mid=l+r>>1;
push_down(root,mid-l+1,r-mid);
if(x<=mid) update(root<<1,l,mid,x,k);
else update(root<<1|1,mid+1,r,x,k);
Max[root]=max(Max[root<<1],Max[root<<1|1]);
}
double query(int root,int l,int r,int x,int y)
{
if(x<=l&&y>=r) return Max[root];
int mid=l+r>>1;
push_down(root,mid-l+1,r-mid);
if(y<=mid) return query(root<<1,l,mid,x,y);
if(x>mid) return query(root<<1|1,mid+1,r,x,y);
return max(query(root<<1,l,mid,x,y),query(root<<1|1,mid+1,r,x,y));
}
void add(int root,int l,int r,int x,int y,double k)
{
double tmp=Max[root];
if(x<=l&&y>=r)
{
Max[root]+=k;
ad[root]+=k;
return;
}
int mid=l+r>>1;
if(x<=mid) add(root<<1,l,mid,x,y,k);
if(y>mid) add(root<<1|1,mid+1,r,x,y,k);
Max[root]=max(Max[root<<1],Max[root<<1|1]);
}
void dfs2(int u,int ff)
{
if(!pos[u]) pos[u]=++tim;
int pu=pos[u];
if(son[u])
{
dfs2(son[u],u);
add(1,1,n,pu+1,pu+len[u]-1,w[u]-x);
}
for(int i=head[u];i;i=ed[i].next)
{
int v=ed[i].to;
if(v==son[u]||v==ff) continue;
dfs2(v,u);
int pv=pos[v];
for(int j=1;j<=len[v];j++)
{
if(j+len[u]-1>=L&&j<=U)
{
double tmp=query(1,1,n,pu+max(1,L-j),pu+min(len[u]-1,U-j));
z=max(z,tmp+ed[i].w-x+query(1,1,n,pv+j-1,pv+j-1));
}
}
for(int j=1;j<=len[v];j++)
{
double tmp=query(1,1,n,pv+j-1,pv+j-1);
if(tmp+ed[i].w-x>query(1,1,n,pu+j,pu+j))
{
update(1,1,n,pu+j,tmp+ed[i].w-x);
}
}
}
if(len[u]-1>=L) z=max(z,query(1,1,n,pu+L,pu+min(U,len[u]-1)));
}
void clear(int root,int l,int r)
{
Max[root]=0;
ad[root]=0;
if(l==r) return;
int mid=l+r>>1;
clear(root<<1,l,mid);
clear(root<<1|1,mid+1,r);
}
bool check(double h)
{
clear(1,1,n);x=h;
z=-1e18;dfs2(1,1);
if(z>=-1e-7) return true;
return false;
}
void dfs1(int u,int ff)
{
len[u]=-1;
for(int i=head[u];i;i=ed[i].next)
{
int v=ed[i].to;
if(v==ff) continue;
dfs1(v,u);
if(len[v]>len[u]) len[u]=len[v],son[u]=v,w[u]=ed[i].w;
}
len[u]++;
}
int main()
{
double l=0,r=0;
scanf("%d%d%d",&n,&L,&U);
for(int i=1;i<n;i++)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
add_edge(u,v,w);
add_edge(v,u,w);
r+=w;
}
dfs1(1,1);
for(int i=1;i<=n;i++) len[i]++;
double ans=0;
r=1e6;
while(r-l>1e-5)
{
double mid=(l+r)/2;
if(check(mid)) l=mid;
else r=mid;
}
printf("%.3lf\n",l);
return 0;
}