最大数据是n,m=300000,所以应该试着把时间压在n*log(xxx)以内 //在本算法下,xxx=n*∑ti
考虑修改一条边后,所有长度大于答案的路径都被缩短,因此二分答案
设二分到的答案为k
那么被删掉的边一定是所有长于k的路径的公共边之一,因此问题转化为,在O(m)时间内求出m条路径的交集中的最长边
首先,可以在O(1)时间内求出两条路径的交集:对于路径s-t 和 s'-t' 分别求出s'到s-t上最近点u,t'到s-t上最近点v,那么交集就是(u,v)
这里若无解,交集会变成s-t上一个点,结束时特判即可。在其他题目中求路径交集需要再反向求一次,在此不详细解释。
求最近点代码如下
int closest(int x,int s,int t)
{
int r=lca(s,t);
if(lca(x,r)!=r) return r;
int p=lca(x,s);
return p==r?lca(x,t):p;
}
大致思路是,若最近点不是s和t 的公共祖先,那么判断是s那侧还是t那侧
其中lca要转化为RMQ,方可O(1)实现
找到交集后,从两个顶点开始逐条边向父节点移动,直到撞上就可以了,效率很低,O(n)但不会影响整体复杂度
代码:
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
#define maxm 300005
#define maxn 300005
#define maxt 600005
#define lca(u,v) rmq(pre[u],pre[v])
#define mid (l+r)/2
int n,m;
struct edge
{
int v;
int w;
int next;
edge(int v,int w,int next):v(v),w(w),next(next){}
edge(){}
}e[maxn*2];
int newedge,dfsclock;
int ind[maxn];
void addedge(int u,int v,int l)
{
e[++newedge]=edge(v,l,ind[u]);
ind[u]=newedge;
}
int f[maxn];
int pre[maxn];
int dis[maxn],dep[maxn];
int visiting[maxt];
int st[maxt][20];
int log[maxt];
void dfs(int u,int fa)
{
f[u]=fa;
dep[u]=dep[fa]+1;
pre[u]=++dfsclock;
visiting[dfsclock]=u;
for(int i=ind[u];i;i=e[i].next)
{
int v=e[i].v;
if(v!=fa)
{
dis[v]=dis[u]+e[i].w;
dfs(v,u);
visiting[++dfsclock]=u;
}
}
}
void rmq_init()
{
log[1]=0;
for(int i=2;i<=dfsclock;i++) log[i]=log[i>>1]+1;
dep[0]=0x7fffffff;
for(int i=1;i<=dfsclock;i++)
st[i][0]=visiting[i];
for(int l=1,t=0;l<=dfsclock;l<<=1,t++)
for(int i=1;i+l<=dfsclock;i++)
st[i][t+1]=dep[st[i][t]]<dep[st[i+l][t]]?st[i][t]:st[i+l][t];
}
int rmq(int l,int r)
{
if(l>r) swap(l,r);
int k=log[r-l+1];
int a=st[l][k],b=st[r-(1<<k)+1][k];
return dep[a]<dep[b]?a:b;
}
int closest(int x,int s,int t)
{
int r=lca(s,t);
if(lca(x,r)!=r) return r;
int p=lca(x,s);
return p==r?lca(x,t):p;
}
struct mission
{
int s,t;
int l;
void input()
{
scanf("%d%d",&s,&t);
l=dis[s]+dis[t]-2*dis[lca(s,t)];
}
}q[maxm];
int ln[maxm];
bool operator <(mission a,mission b)
{
return a.l<b.l;
}
bool judge(int x)
{
int p=upper_bound(ln+1,ln+m+1,x)-ln;
if(p==m+1) return true;
int ss=q[p].s,tt=q[p].t;
int maxlen=0;
for(int i=p+1;i<=m;i++)
{
ss=closest(q[i].s,q[i].t,ss);
tt=closest(q[i].s,q[i].t,tt);
}
if(ss==tt) return false;
while(ss!=tt)
{
int &a=dep[ss]>dep[tt]?ss:tt;
maxlen=max(maxlen,dis[a]-dis[f[a]]);
a=f[a];
}
if(ln[m]-maxlen<=x) return true;
else return false;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
int a,b,t;
scanf("%d%d%d",&a,&b,&t);
addedge(a,b,t);
addedge(b,a,t);
}
dfs(1,0);
rmq_init();
for(int i=1;i<=m;i++) q[i].input();
sort(q+1,q+m+1);
for(int i=1;i<=m;i++) ln[i]=q[i].l;
int l=0,r=0x7fffffff;
while(l<r)
{
if(judge(mid)) r=mid;
else l=mid+1;
}
printf("%d",l);
}