思路
本稿部分参考于嘴上神犇的博客,在此Orz %%%: 嘴上神犇的代码;
大致思路:
总的框架是二分答案;先预处理出模板LCA要处理的祖先还有距离,对于每个路线用结构体存下来,记录起点.终点.LCA以及所花费的时间;
二分答案:check函数,大致是对于当前的时间x,遍历所有的路线,将超过时间的路线记录下来,找出超过时间最多的一个;重点!!!,用sum[]数组来记录该条边被经过的次数,那么就是树上差分了,对于每条路线,将起点,终点的sum+1,LCA的sum-2,这里是以每条边的下端点作标记的(深度大的点),用一遍DFS将所有的sum[]推向它的父节点,最后都会推向根节点;
具体做法: 嘴上神犇的传送门
那么现在如果有一条边,所有超时的路线都经过并且满足删去这条边后,最大时间比x小,那么说明答案可行;如果没有这条边,答案不合法;
#include <cstdio>
#include <iostream>
#include <cstring>
#include <cmath>
using namespace std;
const int MAXN=300000+10;
int sum[MAXN],d[MAXN],p[MAXN][20],g[MAXN][20];
int head[MAXN],k[MAXN];
int n,m,num,a,b,t,r,l,ans,dis,size;
struct Edge {
int to,next,w;
}edge[MAXN<<1];
struct Que{
int s,t,w,ant;
}que[MAXN];
void read(int &in)
{
int f=1,x=0;char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9') {x=x*10+ch-'0';ch=getchar();}
in=x*f;
}
void add(int from,int to,int w)
{
edge[++num].to=to;
edge[num].w=w;
edge[num].next=head[from];
head[from]=num;
}
void dfs(int u)
{
k[++size]=u;//这里是重点,k[]中深度小的在k[]序列的前段,而深度大的在后面,也就保证了下面可以用循环代替dfs的正确性
for(int i=head[u];i;i=edge[i].next)
if(!d[edge[i].to])
{
int to=edge[i].to;
d[to]=d[u]+1;
p[to][0]=u;
g[to][0]=edge[i].w;
dfs(to);
}
}
void init()
{
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i<=n;i++)
if(p[i][j-1])
p[i][j]=p[p[i][j-1]][j-1],g[i][j]=g[p[i][j-1]][j-1]+g[i][j-1];
}
int lca(int a,int b)
{
dis=0;
if(d[a]>d[b]) swap(a,b);
int f=d[b]-d[a];
for(int i=0;(1<<i)<=f;i++)
if((1<<i)&f) dis+=g[b][i],b=p[b][i];
if(a==b) return a;
for(int i=(int)log2(n);i>=0;i--)
if(p[a][i]!=p[b][i])
dis+=g[a][i]+g[b][i],a=p[a][i],b=p[b][i];
dis+=g[a][0]+g[b][0];
return p[a][0];
}
int dfs2(int u)
{
for(int i=head[u];i;i=edge[i].next)
if(edge[i].to!=p[u][0])
sum[u]+=dfs2(edge[i].to);
return sum[u];
}
bool check(int x)
{
int maxdis=0,cnt=0;
memset(sum,0,sizeof sum);
for(int i=1;i<=m;i++)
{
if(que[i].w>x)
{
++sum[que[i].s],++sum[que[i].t];
sum[que[i].ant]-=2;
maxdis=max(maxdis,que[i].w-x);
cnt++;
}
}
//dfs2(1); //好吧,说实话,dfs会超时的;
for(int i=n;i>2;i--)//将所有深度深的点的sum值加至其父亲处;
sum[p[k[i]][0]]+=sum[k[i]];
for(int i=2;i<=n;i++)
if(sum[i]==cnt&&g[i][0]>=maxdis)//如果去掉g[i][0]这条边并且所有超时边都经过该点,答案合法;
return true;
return false;
}
int main()
{
read(n),read(m);
for(int i=1;i<n;i++)
{
read(a),read(b),read(t);
add(a,b,t),add(b,a,t);
r+=t;
}
d[1]=1;
dfs(1);
init();
for(int i=1;i<=m;i++)
{
read(que[i].s),read(que[i].t);
que[i].ant=lca(que[i].s,que[i].t);
que[i].w=dis;
}
while(l<=r)
{
int mid=(l+r)>>1;
if(check(mid)) ans=mid,r=mid-1;
else l=mid+1;
}
printf("%d",ans);
return 0;
}