我们二分答案,如何check呢
我们要使最长的路径最短或者小于一个值,易证一定是使最长的那条路变短
那么现在我们要使长度超过mid的路径变短,我们一定是找一条在这几条路径上的
公
共
公共
公共边中最大的一条,如果最长边减去该边长度小于mid,就符合,不然就不符合
至于怎么统计路径,那就是树上差分的拿手好戏了
我们用come[i]表示点i上方边的权值,num[i]表示i点上方边被经过的次数,pick为差分数组,对于每一条长度大于mid的路径
我们差分一下就行,最后一遍dfs统计
如果num[i]等于长度大于mid的路径总数,我们统计一次(maxlen-come[i],ans)
如果最后ans<=mid return true
不然 return false;
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<queue>
using namespace std;
const int maxn = 500007;
const int N = 30;
const int INF = 2147483647;
struct node
{
int to,next,w;
}edge[2*maxn];
int cnt,head[maxn];
void add(int from,int to,int w)
{
edge[++cnt].to=to;
edge[cnt].next=head[from];
edge[cnt].w=w;
head[from]=cnt;
}
int dp[maxn][N+1],dis[maxn],Log[maxn],dep[maxn];
int n,m,mi[50],come[maxn],pick[maxn],num[maxn];
struct nodfe
{
int from,to,lca,dis;
}g[maxn];
bool cmp1(nodfe a,nodfe b){ return a.dis>b.dis ;};
void dfs1(int u,int fa)
{
dep[u]=dep[fa]+1;
for(int i=1;i<=Log[dep[u]];i++)
dp[u][i]=dp[dp[u][i-1]][i-1];
for(int i=head[u];i;i=edge[i].next)
{
int to=edge[i].to;
int w=edge[i].w;
if(to==fa)continue;
come[to]=w;
dis[to]=dis[u]+w;
dp[to][0]=u;
dfs1(to,u);
}
}
int lca(int u,int v)
{
if(dep[u]<dep[v])swap(u,v);
for(int i=N;i>=0;i--)
{
if(dep[dp[u][i]]>=dep[v]&&mi[i]<dep[u])
{
//cout<<u<<" "<<i<<endl;
u=dp[u][i];
}
}
if(u==v)return v;
for(int i=N;i>=0;i--)
{
if(dp[u][i]!=dp[v][i]&&mi[i]<dep[u])
{
u=dp[u][i],v=dp[v][i];
}
}
return dp[u][0];
}
void dfs2(int u,int fa)
{
num[u]=pick[u];
for(int i=head[u];i;i=edge[i].next)
{
int to=edge[i].to;
if(to==fa)continue;
dfs2(to,u);
num[u]+=num[to];
}
}
bool check(int mid)
{
int ret=g[1].dis,re=g[1].dis;
memset(pick,0,sizeof(pick));
memset(num,0,sizeof(num));
int number=0;
for(int i=1;i<=m;i++)
{
if(g[i].dis<=mid)break;
number++;
pick[g[i].from]++;
pick[g[i].to]++;
pick[g[i].lca]-=2;
}
//cout<<ret<<endl;
dfs2(1,0);
for(int i=1;i<=n;i++)
if(num[i]>=number)
re=min(re,ret-come[i]);//,cout<<re<<" "<<i<<" "<<come[i]<<endl;
if(re>mid)return false;
else return true;
/*for(int i=1;i<=n;i++)cout<<num[i]<<" ";
cout<<endl;*/
}
int main()
{
int l=0,r=0;
Log[0]=-1;
mi[0]=1;
for(int i=1;i<=N+1;i++)mi[i]=2*mi[i-1];
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)Log[i]=Log[i>>1]+1;
for(int i=1;i<n;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
r+=z;
add(x,y,z),add(y,x,z);
}
dfs1(1,0);
/*for(int i=1;i<=n;i++)
{
cout<<come[i]<<" ";
//for(int j=Log[dep[i]];j>=0;j--)cout<<i<<" "<<j<<" "<<dp[i][j]<<endl;
}*/
for(int i=1;i<=m;i++)
{
int x,y;
scanf("%d%d",&x,&y);
g[i].from=x,g[i].to=y;
g[i].lca=lca(x,y);
g[i].dis=dis[x]+dis[y]-2*dis[g[i].lca];
//cout<<x<<" "<<y<<" "<<g[i].lca<<" "<<g[i].dis<<endl;
}
sort(g+1,g+1+m,cmp1);
int ans=INF;
while(l<=r)
{
int mid=l+r>>1;
if(check(mid))ans=min(ans,mid),r=mid-1;
else l=mid+1;
}
cout<<ans;
return 0;
}
/*
6 3
1 2 3
1 6 4
3 1 7
4 3 6
3 5 5
3 6
2 5
4 5
*/
/*
10 4
1 5 4
2 5 3
3 5 2
4 5 1
5 6 15161
6 10 5
6 9 6
6 8 7
6 7 8
1 10
2 9
3 8
4 7
*/