题目大意:给你一棵n个点的边带权树,再给树上m条路径,让你将一条边权改为0,使得最后路径的最大值最小
题解:
1、对于求最大值最小,首先想到二分答案,想了想发现没毛病,怎么验证;
2、考虑每次找出比当前二分的值大的路径,想将一条他们都经过的边改为0,使得所有边都小于当前二分的值;
3.于是问题转化为了,求树上路径交,我们可以选择树上差分。
4、第一次DFS时记录一下每个点的DFS序,因为一棵子树的DFS序是一段连续的区间,方便差分计算。
5、我们用一条边指向的点来代替它统计(当然也可以在边上计算)—每次求路径交就是在起点,终点++,LCA-=2;最后按DFS序加就好了
6、优化,我们可以在二分时记录当前值之前有没有算过,如果有,就直接return ,可以用一个数组记录这个答案,减小常数
#include<bits/stdc++.h>
using namespace std;
const int N = 300005;
int n,m,x,y,z,tot,dfs_sort,l,r;
int first[N],to[N*2],nxt[N*2],val[N*2],memery[N];
int deep[N],dis[N],num[N],tmp[N],v[N],f[N][20];
struct node{
int u,v,lca,len;
bool operator <(const node&a)const {
return len>a.len;
}
}c[N];
inline int Readint(){
int i=0;char c;
for(c=getchar();!isdigit(c);c=getchar());
for(;isdigit(c);c=getchar()) i=(i<<1)+(i<<3)+c-'0';
return i;
}
inline void add(int x,int y,int z){
nxt[++tot]=first[x];first[x]=tot;to[tot]=y;val[tot]=z;
nxt[++tot]=first[y];first[y]=tot;to[tot]=x;val[tot]=z;
}
inline void Dfs(int x,int fa){
num[++dfs_sort]=x;//DFS序
for(int i=1;i<=18;i++){
f[x][i]=f[f[x][i-1]][i-1];
if(!f[x][i-1]) break;
}
for(int i=first[x];i;i=nxt[i]){
if(to[i]==fa) continue;
deep[to[i]]=deep[x]+1;
dis[to[i]]=dis[x]+val[i];
v[to[i]]=val[i];//将边的值转到它指向的那个点,方便差分
f[to[i]][0]=x;
Dfs(to[i],x);
}
}
inline int lca(int a,int b){
if(deep[a]<deep[b]) swap(a,b);
int k=deep[a]-deep[b];
for(int i=18;i>=0;i--)
if(k&(1<<i)) a=f[a][i];
if(a==b) return a;
for(int i=18;i>=0;i--)
if(f[a][i]!=f[b][i]) a=f[a][i],b=f[b][i];
return f[a][0];
}
inline bool check(int x){
int cnt=0,limit=0,mx=0;
memset(tmp,0,sizeof(tmp));
for(int i=1;i<=m;i++){
if(c[i].len>x){
++tmp[c[i].u];++tmp[c[i].v];tmp[c[i].lca]-=2;
limit=max(limit,c[i].len-x);
cnt++;
}
else break;//把边从大到小排序,优化常数
}
if(memery[cnt]) return memery[cnt]>=limit;//优化
if(!cnt) return true;
for(int i=n;i>=1;i--) tmp[f[num[i]][0]]+=tmp[num[i]];
//计算一条边经过的次数
for(int i=2;i<=n;i++)
if(tmp[i]==cnt) mx=max(mx,v[i]);
memery[cnt]=mx;
return memery[cnt]>=limit;
}
int main(){
//freopen("lx.in","r",stdin);
ios::sync_with_stdio(false);
cin.tie(NULL);
n=Readint(),m=Readint();
for(int i=1;i<n;i++){
x=Readint(),y=Readint();
z=Readint(),add(x,y,z);
r+=z;
}
Dfs(1,0);
for(int i=1;i<=m;i++){
c[i].u=Readint();
c[i].v=Readint();
c[i].lca=lca(c[i].u,c[i].v);
c[i].len=dis[c[i].u]+dis[c[i].v]-2*dis[c[i].lca];
// r=max(r,c[i].len);
}
sort(c+1,c+1+m);
int ans=r;
while(l<=r){
int mid=(l+r)>>1;
if(check(mid)) ans=min(ans,mid),r=mid-1;
else l=mid+1;
}
cout<<ans;
}