二分答案比较容易想到,但是对于判断不太好进行,所以要使用新知识:树上差分。
对于二分出的时间,我们需要判断所有的运输计划是否在添加一个虫洞以后都达到要求,所以我们首先将所有计划按照距离排序,将大于要求的计划在树上进行差分,然后统计是否有边在所有计划中,并取其中最长的一条,判断如果将其改成虫洞是否达到要求,并更新答案。
说得比较啰嗦,看不懂就看这一篇文章吧。
吐槽:cogs最后一个点极其卡常,使用各种inline和register才卡过。
CODE:
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=3e5+10;
const int INF=1e9;
struct edge
{
int nxt,to,dis;
}a[N<<1];
struct work
{
int s,t,dis,lca;
inline bool operator <(const work other)const
{
return dis>other.dis;
}
}w[N];
int head[N],f[N],deep[N],size[N],son[N],top[N],dis[N],sum[N];
int n,m,x,y,z,num,tot,L,R,mid,ans,maxcost;
inline int max(const int &a,const int &b){return a>b?a:b;}
inline int min(const int &a,const int &b){return a<b?a:b;}
inline void swap(int &a,int &b){a^=b,b^=a,a^=b;}
inline void read(int &n)
{
n=0;char c=getchar();
while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') n=n*10+c-48,c=getchar();
}
inline void add(int x,int y,int z)
{
a[++num].nxt=head[x],a[num].to=y,a[num].dis=z,head[x]=num;
a[++num].nxt=head[y],a[num].to=x,a[num].dis=z,head[y]=num;
}
inline void dfs(int now)
{
size[now]=1;
int tmp=-INF;
for(register int i=head[now];i;i=a[i].nxt)
if(a[i].to!=f[now])
{
int v=a[i].to;
f[v]=now;
deep[v]=deep[now]+1;
dis[v]=dis[now]+a[i].dis;
dfs(v);
size[now]+=size[v];
if(size[v]>tmp) tmp=size[v],son[now]=v;
}
}
inline void dfs2(int now,int high)
{
top[now]=high;
if(son[now]) dfs2(son[now],high);
for(register int i=head[now];i;i=a[i].nxt)
if(a[i].to!=f[now]&&a[i].to!=son[now]) dfs2(a[i].to,a[i].to);
}
inline int LCA(int x,int y)
{
while(top[x]!=top[y])
if(deep[top[x]]>deep[top[y]]) x=f[top[x]];
else y=f[top[y]];
return deep[x]<deep[y]?x:y;
}
inline int calc(int now)
{
for(register int i=head[now];i;i=a[i].nxt)
if(a[i].to!=f[now]) sum[now]+=calc(a[i].to);
if(sum[now]==tot) maxcost=max(maxcost,dis[now]-dis[f[now]]);
int ans=sum[now];sum[now]=0;
return ans;
}
inline bool check(int dist)
{
tot=m;
for(register int i=1;i<=m;i++)
if(w[i].dis>dist) tot=i,sum[w[i].s]++,sum[w[i].t]++,sum[w[i].lca]-=2;
maxcost=0;
calc(1);
if(w[1].dis-maxcost<=dist) return 1;
return 0;
}
int main()
{
read(n),read(m);
for(register int i=1;i<n;i++)
read(x),read(y),read(z),add(x,y,z);
dfs(1),dfs2(1,1);
for(register int i=1;i<=m;i++)
{
read(w[i].s),read(w[i].t);
w[i].lca=LCA(w[i].s,w[i].t);
w[i].dis=dis[w[i].s]+dis[w[i].t]-(dis[w[i].lca]<<1);
R=max(R,w[i].dis);
}
sort(w+1,w+m+1);
while(L<=R)
{
int mid=(L+R)>>1;
if(check(mid)) ans=mid,R=mid-1;
else L=mid+1;
}
printf("%d",ans);
return 0;
}