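// Approach: binary search on the answer mid. For a candidate mid, collect every route longer than mid. All of them can be brought down to at most mid only if a single edge (whose length is dropped to zero) lies on every such route and is at least as long as (longest route length - mid). Tree edge-difference counting (s[u]++, s[v]++, s[lca(u,v)] -= 2, then subtree sums) yields, for every vertex, how many of these long routes use the edge to its parent. See the code for details.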
#include <cstdio>
#include <cmath>
#include <ctime>
#include <string>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <vector>
// Short-hands used throughout: pb = push_back; forup/fordown are inclusive
// ascending/descending int loops over [a, b] / [b, a].
#define pb push_back
#define forup(i,a,b) for(int i=(a);i<=(b);i++)
#define fordown(i,a,b) for(int i=(a);i>=(b);i--)
// Capacity of the per-vertex and per-route arrays (both are indexed from 1,
// so this must exceed the largest n and m that will be read).
#define maxn 300005
// NOTE(review): maxm, INF, ll and ull are not referenced by the code shown here.
#define maxm 100005
#define INF 1070000000
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
// Reads the next (optionally negative) decimal integer from stdin into num.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. If EOF is reached before any digit, num is left
// as 0 instead of looping forever (ch is an int so EOF is distinguishable
// from every real character, even where plain char is unsigned).
template<class T> inline
void read(T& num)
{
    num = 0;
    bool negative = false;
    int ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return;
        if (ch == '-') negative = true;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        num = num * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) num = -num;
}
// Digit scratch buffer for write(); 100 digits is far more than any builtin integer needs.
int out[100];
// Writes x in decimal to stdout followed by the character ch.
// Negative values are printed with a leading '-'. Digits are peeled off the
// negative value directly (C++11 '/' and '%' truncate toward zero), because
// negating the most negative value of T (e.g. INT_MIN) would overflow.
template<class T> inline
void write(T x,char ch)
{
    int num = 0;
    if (x < 0) {
        putchar('-');
        do { out[num++] = -(int)(x % 10); x /= 10; } while (x != 0);
    } else {
        do { out[num++] = (int)(x % 10); x /= 10; } while (x != 0);
    }
    while (num > 0) putchar('0' + out[--num]);
    putchar(ch);
}
/*==================split line==================*/
// Per-vertex tree data (vertices are 1..n), filled in by dfs() and mark():
//   fa[v]     parent of v (0 for the root, which dfs never assigns)
//   siz[v]    number of vertices in v's subtree
//   sonh[v]   a child of v with the largest subtree (0 when v has no child)
//   topic[v]  topmost vertex of the heavy chain containing v
//   q[1..cnt] vertices in the order dfs first reaches them (parents before children)
//   depth[v]  number of edges from the root to v
//   dis[v]    sum of edge lengths from the root to v
int fa[maxn],siz[maxn],sonh[maxn],topic[maxn],q[maxn],depth[maxn],dis[maxn];
// Weighted edge; main() stores every input edge in both endpoints' adjacency lists g[].
struct Edge
{int to,len;};vector<Edge> g[maxn];
// Number of vertices and number of routes.
int n,m;
// Total length of all edges, accumulated while reading input.
int sum;
// Number of vertices appended to q so far.
int cnt;
// len[i]: length of route i (tree distance between its two endpoints).
int len[maxn];
// Roots the tree at x (main calls dfs(1)) and returns siz[x].
// For every vertex v in x's subtree it appends v to q (in the same preorder a
// recursive walk over g[] would give), and sets siz[v] and sonh[v] (the first
// child in adjacency order with the strictly largest subtree, 0 for a leaf).
// For every proper descendant u it also sets fa[u], depth[u] and dis[u];
// x's own fa/depth/dis are left untouched.
// An explicit stack replaces recursion: a path-shaped tree with up to maxn
// vertices would otherwise recurse that deep and overflow the call stack.
int dfs(int x)
{
    const int first = cnt + 1;           // q[first..cnt] will hold x's subtree
    std::vector<int> stk(1, x);
    // Pass 1: preorder walk. Children are pushed in reverse adjacency order so
    // that they are popped (and so appended to q) in forward order.
    while (!stk.empty()) {
        const int v = stk.back();
        stk.pop_back();
        q[++cnt] = v;
        siz[v] = 1;
        for (int i = (int)g[v].size() - 1; i >= 0; --i) {
            const int u = g[v][i].to;
            if (u == fa[v]) continue;    // skip the edge back to the parent
            fa[u] = v;
            depth[u] = depth[v] + 1;
            dis[u] = dis[v] + g[v][i].len;
            stk.push_back(u);
        }
    }
    // Pass 2: reverse preorder visits every child before its parent, so each
    // subtree size is complete before it is added to the parent's.
    for (int i = cnt; i > first; --i) siz[fa[q[i]]] += siz[q[i]];
    // Pass 3: heavy child = first child (adjacency order) with the largest size.
    for (int i = first; i <= cnt; ++i) {
        const int v = q[i];
        int best = 0, heavy = 0;
        for (int j = 0; j < (int)g[v].size(); ++j) {
            const int u = g[v][j].to;
            if (u != fa[v] && best < siz[u]) best = siz[u], heavy = u;
        }
        sonh[v] = heavy;
    }
    return siz[x];
}
// Heavy-light labelling: sets topic[v] for every vertex v in x's subtree.
// x gets `head`; following heavy children (sonh) keeps the same head, while
// every other child u starts a new chain with head u.
// Uses an explicit work list instead of recursion, because the recursive form
// nests once per vertex along a heavy chain (up to the whole tree on a path).
// Requires fa[] and sonh[] to have been filled by dfs().
void mark(int x,int head)
{
    std::vector<std::pair<int,int> > pending(1, std::make_pair(x, head));
    while (!pending.empty()) {
        int v = pending.back().first;
        const int h = pending.back().second;
        pending.pop_back();
        // Walk down the heavy chain starting at v, labelling it with h and
        // queueing each light child as the head of its own chain.
        while (true) {
            topic[v] = h;
            for (int i = 0; i < (int)g[v].size(); ++i) {
                const int u = g[v][i].to;
                if (u != fa[v] && u != sonh[v]) pending.push_back(std::make_pair(u, u));
            }
            if (!sonh[v]) break;
            v = sonh[v];
        }
    }
}
// Lowest common ancestor of x and y using the heavy-light chains built by
// dfs()/mark(): while the two vertices sit on different chains, the one whose
// chain top is deeper jumps to the parent of that top. Once they share a
// chain, the shallower of the two is the answer.
int lca(int x,int y)
{
    if (x == y) return x;
    while (topic[x] != topic[y]) {
        const int topX = topic[x];
        const int topY = topic[y];
        if (depth[topX] >= depth[topY]) x = fa[topX];
        else y = fa[topY];
    }
    return depth[x] < depth[y] ? x : y;
}
// query[i] = {first endpoint, second endpoint, LCA of the two} for route i.
int query[maxn][3];
// Scratch counters for pd(): per-vertex difference values that, after subtree
// summation, count the long routes using the edge from a vertex to its parent.
int s[maxn];
bool pd(int mid)
{ memset(s,0,sizeof(s));
int cnt=0,lim=0;
forup(i,1,m) {if(len[i]>mid){cnt++;lim=max(lim,len[i]-mid);s[query[i][0]]++;s[query[i][1]]++;s[query[i][2]]-=2;}}
if(!cnt) return 1;
fordown(i,n,1) s[fa[q[i]]]+=s[q[i]];
forup(i,1,n)
for(int j=0;j<g[i].size();j++)
{int u=g[i][j].to;
if(u!=fa[i]&&s[u]==cnt&&g[i][j].len>=lim) return 1;
}
return 0;
}
// Reads n, m, the n-1 weighted tree edges and the m routes, then prints the
// smallest achievable maximum route length (the smallest mid with pd(mid) true).
int main()
{
    cin>>n>>m;
    forup(i,1,n-1){ int x,y,z;read(x);read(y);read(z); g[x].pb((Edge){y,z});g[y].pb((Edge){x,z});sum+=z;};
    dfs(1);
    mark(1,1);
    int maxLen=0;   // longest route; pd(maxLen) is always true (no route exceeds it)
    forup(i,1,m)
    {
        int x,y; read(x);read(y);
        query[i][0]=x; query[i][1]=y; query[i][2]=lca(x,y);
        len[i]=dis[x]+dis[y]-2*dis[query[i][2]];
        if(len[i]>maxLen) maxLen=len[i];
    }
    // Invariant: pd(l) is false and pd(r) is true. l starts at the sentinel -1
    // (never evaluated) rather than 1, because answers 0 and 1 are possible,
    // e.g. a lone route over a single edge of length 5 can be reduced to 0.
    int l=-1,r=maxLen;
    while(r-l>1)
    {
        int mid=l+(r-l)/2;
        if(pd(mid)) r=mid;
        else l=mid;
    }
    cout<<r;
    return 0;
}