题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=3219
题目大意:给出一棵n个点n-1条边的树,求出一条包含[L,R]条边的路径,使得这条路径上的边权的中位数尽量大,输出这个答案。
数据范围:n<=100 000,1<=L<=R<=n-1
题解:因为要使得中位数尽量大,所以我们采用二分判定。对于当前二分出的答案mid,把所有边权比它小的边看成-1,边权大于等于它的边看成1,这样当某条路径的边权和大于等于0时,说明在这条路径上有不少于一半的边大于等于mid,即这条路径的中位数大于等于mid。
于是问题就转化成了,在树上找一条长度在[L,R]之间的路径,使得它的边权总和大于等于0,可以用树分治来解决。
接下来的问题是如何判定。
假设点 i 到根的距离为dep[ i ],边权和为g[ i ],那么我们应该找到一个 j 满足到根的距离dep[ j ]∈[L-dep[i],R-dep[i]]的g[ j ]的值最大的点,判断g[i]+g[j]是否大于等于0。我们发现,对于dep相同的点,我们只需要取最大的就可以了,设 f [ j ]=max(g[v]),其中dep[v]=j,对于 i ,寻找的区间应该是 f [L-dep[i]…R-dep[i]]。当前点的dep改变时,相应的寻找区间也跟着移动,当我们使用bfs遍历根的子树时,当前点的dep值是非降序的,寻找的区间的移动方向也是固定的。具体来说,当dep增大时,L-dep[i]和R-dep[i]相应减小。所以我们可以用一个单调队列来处理,当dep增加时,把dep大于R-dep[i]的点从队头删除,如果dep[i]<=L,把 f [L-dep[i]]加入队尾,然后每次判断队头+g[i]是否大于等于0就可以了。
由于每次遍历一个新的子树时,一开始应该把[maxdep,L]范围内的点存入队列(其中maxdep为之前子树的最大深度),这样如果第一个子树很大,而后面的子树很小时(这是有多坑才能有这种数据),时间复杂度可能会退化成平方级别。所以我们一开始先做一遍树分治预处理,把每一次的重心的子树按子树中最大深度从小到大排序,然后重新连边,这样遍历到一个新的子树时,maxdep的值不会大于当前子树的深度。到此,问题解决。
时间复杂度O(n*log^2n)
(ps:大视野上好像并没有十分刁钻的数据,所以不预处理排序虽然比较慢但是也能过。并且由于预处理代码实在太丑(事实上本蒟蒻的代码都很丑),所以把没有预处理的版本也贴了上来)
代码如下:
1.非预处理版本
#include <algorithm>
#include <cstdio>
const int N=100005;
int to[N*2],ne[N*2],val[N*2],v[N*2],fi[N],V[N],num[N],mx[N],fa[N],
dep[N],f[N],g[N],q[N],d[N],u[N],n,L,R,tot=0,cur,rt;
void add(int x,int y,int z){
to[++tot]=y;val[tot]=z;ne[tot]=fi[x];fi[x]=tot;
}
void findrt(int x){
num[x]=1;mx[x]=0;
for (int i=fi[x];i;i=ne[i])
if (!u[to[i]] && to[i]!=fa[x]){
fa[to[i]]=x;
findrt(to[i]);
num[x]+=num[to[i]];
mx[x]=std::max(mx[x],num[to[i]]);
}
mx[x]=std::max(mx[x],cur-num[x]);
if (mx[x]<mx[rt]) rt=x;
}
bool work(int x,int sum){
if (sum<=L) return 0;
cur=sum;rt=0;
findrt(x);u[x=rt]=1;
int maxdep=0;f[0]=0;
for (int i=fi[x];i;i=ne[i])
if (!u[to[i]]){
int s=1,t=1,head=1,tail=0;
for (int j=maxdep;j>=L;j--){
for (;head<=tail && f[d[tail]]<=f[j];tail--);
d[++tail]=j;
}
fa[q[1]=to[i]]=x;dep[to[i]]=1;g[to[i]]=v[i];
for (;s<=t;s++){
for (;head<=tail && d[head]+dep[q[s]]>R;head++);
if (dep[q[s]]<=L){
for (;head<=tail && f[d[tail]]<=f[L-dep[q[s]]];tail--);
d[++tail]=L-dep[q[s]];
}
if (head<=tail && f[d[head]]+g[q[s]]>=0) return 1;
if (dep[q[s]]>=R) continue;
for (int j=fi[q[s]];j;j=ne[j])
if (!u[to[j]] && to[j]!=fa[q[s]]){
fa[q[++t]=to[j]]=q[s];
dep[to[j]]=dep[q[s]]+1;
g[to[j]]=g[q[s]]+v[j];
}
}
maxdep=std::max(maxdep,dep[q[t]]);
for (int j=1;j<=t;j++) f[dep[q[j]]]=std::max(f[dep[q[j]]],g[q[j]]);
}
for (int i=0;i<=maxdep;i++) f[i]=-n;
for (int i=fi[x];i;i=ne[i])
if (!u[to[i]]){
if (num[to[i]]>num[x]) num[to[i]]=sum-num[x];
if (work(to[i],num[to[i]])) return 1;
}
return 0;
}
bool check(int mid){
for (int i=1;i<=tot;i++)
if (val[i]<mid) v[i]=-1;else v[i]=1;
for (int i=0;i<=n;i++) u[i]=fa[i]=0,f[i]=-n;
mx[0]=n;
return work(1,n);
}
int main(){
scanf("%d%d%d\n",&n,&L,&R);
for (int i=1;i<n;i++){
int x,y,z;
scanf("%d%d%d\n",&x,&y,&z);
add(x,y,z);add(y,x,z);V[i]=z;
}
std::sort(V+1,V+n);
int m=std::unique(V+1,V+n)-V-1;
V[0]=-1;
int l=1,r=m,ans=0,mid;
for (;l<=r;){
mid=(l+r)>>1;
if (check(V[mid])) ans=mid,l=mid+1;
else r=mid-1;
}
printf("%d\n",V[ans]);
}
2.预处理版本
#include <algorithm>
#include <cstdio>
const int N=100005;
int to[N*2],ne[N*2],val[N*2],v[N*2],fi[N],V[N],num[N],mx[N],
fa[N],dep[N],f[N],g[N],q[N],d[N],u[N],w[N],a[N],ad[N*2],
n,L,R,tot=0,cur,rt,cnt=0,now;
bool cmp(int x,int y){return ad[x]<ad[y];}
void add(int x,int y,int z){
to[++tot]=y;val[tot]=z;ne[tot]=fi[x];fi[x]=tot;
}
void findrt(int x){
num[x]=1;mx[x]=0;
for (int i=fi[x];i;i=ne[i])
if (!u[to[i]] && to[i]!=fa[x]){
fa[to[i]]=x;
findrt(to[i]);
num[x]+=num[to[i]];
mx[x]=std::max(mx[x],num[to[i]]);
}
mx[x]=std::max(mx[x],cur-num[x]);
if (mx[x]<mx[rt]) rt=x;
}
void Pre(int x,int sum){
cur=sum;rt=0;
findrt(x);u[x=rt]=1;
if (sum<=L) return;
w[++cnt]=x;
int m=0;
for (int i=fi[x];i;i=ne[i])
if (!u[to[i]]){
int s=1,t=1;
fa[q[1]=to[i]]=x;dep[to[i]]=1;
for (;s<=t;s++){
if (dep[q[s]]>=R) continue;
for (int j=fi[q[s]];j;j=ne[j])
if (!u[to[j]] && to[j]!=fa[q[s]]){
fa[q[++t]=to[j]]=q[s];
dep[to[j]]=dep[q[s]]+1;
}
}
ad[a[++m]=i]=dep[q[t]];
}
else ad[a[++m]=i]=n;
std::sort(a+1,a+m+1,cmp);
fi[x]=a[1];
for (int i=1;i<m;i++) ne[a[i]]=a[i+1];
ne[a[m]]=0;
for (int i=fi[x];i;i=ne[i])
if (!u[to[i]]){
if (num[to[i]]>num[x]) num[to[i]]=sum-num[x];
Pre(to[i],num[to[i]]);
}
}
bool work(){
if (now>=cnt) return 0;
int x=w[++now];
u[x]=1;
int maxdep=0;f[0]=0;
for (int i=fi[x];i;i=ne[i])
if (!u[to[i]]){
int s=1,t=1,head=1,tail=0;
for (int j=maxdep;j>=L;j--){
for (;head<=tail && f[d[tail]]<=f[j];tail--);
d[++tail]=j;
}
fa[q[1]=to[i]]=x;dep[to[i]]=1;g[to[i]]=v[i];
for (;s<=t;s++){
for (;head<=tail && d[head]+dep[q[s]]>R;head++);
if (dep[q[s]]<=L){
for (;head<=tail && f[d[tail]]<=f[L-dep[q[s]]];tail--);
d[++tail]=L-dep[q[s]];
}
if (head<=tail && f[d[head]]+g[q[s]]>=0) return 1;
if (dep[q[s]]>=R) continue;
for (int j=fi[q[s]];j;j=ne[j])
if (!u[to[j]] && to[j]!=fa[q[s]]){
fa[q[++t]=to[j]]=q[s];
dep[to[j]]=dep[q[s]]+1;
g[to[j]]=g[q[s]]+v[j];
}
}
maxdep=std::max(maxdep,dep[q[t]]);
for (int j=1;j<=t;j++) f[dep[q[j]]]=std::max(f[dep[q[j]]],g[q[j]]);
}
for (int i=0;i<=maxdep;i++) f[i]=-n;
for (int i=fi[x];i;i=ne[i])
if (!u[to[i]] && work()) return 1;
return 0;
}
bool check(int mid){
for (int i=1;i<=tot;i++)
if (val[i]<mid) v[i]=-1;else v[i]=1;
for (int i=0;i<=n;i++) u[i]=0,f[i]=-n;
now=0;
return work();
}
int main(){
scanf("%d%d%d\n",&n,&L,&R);
for (int i=1;i<n;i++){
int x,y,z;
scanf("%d%d%d\n",&x,&y,&z);
add(x,y,z);add(y,x,z);V[i]=z;
}
mx[0]=n;
Pre(1,n);
std::sort(V+1,V+n);
int m=std::unique(V+1,V+n)-V-1;
V[0]=-1;
int l=1,r=m,ans=0,mid;
for (;l<=r;){
mid=(l+r)>>1;
if (check(V[mid])) ans=mid,l=mid+1;
else r=mid-1;
}
printf("%d\n",V[ans]);
}