题目大意: 给你一棵n个节点的树,有边权,有多个任务,每个要求从ui号节点到 vi号节点去。m 个计划, 这 m 个计划会同时开始。当这 m 个任务都完成时,工作完成。
现在可以把任意一个边的边权变为0,试求出完成工作所需要的最短时间是多少?
题解:先求出每个任务原来的所需时间,一种想法是枚举改变哪条边,但这肯定会超时(不然怎么是T3),然后我们可以想到,我们可以发现必须把最长的时间缩短,才可以把答案时间缩短,于是我们可以用二分它完成的时间,然后进行判断,把时间超过这个二分答案的任务记录一下,然后可以发现,我们要减少其中所有任务都经过的边才可以把所有任务的时间变短。
那该怎么做呢?就要用树上差分了(节点版),我们用一个数组存每个节点的差分值,它的子树的值的和就是它所要经过的次数,假设任务是从a->b,就可以s[a]++,s[b]++,s[lca(a,b)]-=2(lca要存下来,不然也会超时),具体可以baidu。
C++ Code:
#include<cstdio>
#include<cstring>
#define maxn 300100
#define maxm 21
#define inf 0x7ffffff
using namespace std;
int cnt,head[maxn];
int n,m,ans=inf;
int s[maxn],e[maxn];
int lca[maxn],deep[maxn],fa[maxn][maxm],sum[maxn],sum2[maxn];
int p[maxn];
struct Edge{
int to,nxt,cost;
}edge[maxn<<1];
void add(int a,int b,int c){
cnt++;
edge[cnt].to=b;
edge[cnt].nxt=head[a];
edge[cnt].cost=c;
head[a]=cnt;
}
void dfs(int root){
for (int i=head[root];i;i=edge[i].nxt){
int ne=edge[i].to;
if (deep[ne]==0){
deep[ne]=deep[root]+1;
sum[ne]=sum[root]+edge[i].cost;
fa[ne][0]=root;
dfs(ne);
}
}
}
void init(){
for (int i=1;i<maxm;i++){
for (int j=1;j<=n;j++){
fa[j][i]=fa[fa[j][i-1]][i-1];
}
}
}
int LCA(int x,int y){
if (deep[x]<deep[y])x^=y^=x^=y;
for (int i=maxm-1;i>=0;i--){
if (deep[fa[x][i]]>=deep[y]){
x=fa[x][i];
}
}
if (x==y)return x;
for (int i=maxm-1;i>=0;i--){
if (fa[x][i]!=fa[y][i]){
x=fa[x][i];
y=fa[y][i];
}
}
return fa[x][0];
}
void dfs1(int root){
if (head[root]==0)return;
for (int i=head[root];i;i=edge[i].nxt){
int ne=edge[i].to;
if (deep[ne]==deep[root]+1){
dfs1(ne);
p[root]+=p[ne];
}
}
}
bool check(int mid){
int maxo=0,num=0;
memset(p,0,sizeof(p));
for (int i=1;i<=n;i++){
if (sum2[i]>mid){
num++;
if (maxo<(sum2[i]-mid))maxo=sum2[i]-mid;
p[s[i]]++;
p[e[i]]++;
p[lca[i]]-=2;
}
}
dfs1(1);
for (int i=1;i<=n;i++){
if (p[i]==num){
if (sum[i]-sum[fa[i][0]]>=maxo)return 1;
}
}
return 0;
}
int main(){
scanf("%d%d",&n,&m);
for (int i=1;i<n;i++){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add(a,b,c);
add(b,a,c);
}
deep[1]=1;
dfs(1);
init();
int r=0;
for (int i=1;i<=m;i++){
scanf("%d%d",&s[i],&e[i]);
lca[i]=LCA(s[i],e[i]);
sum2[i]=(sum[s[i]]+sum[e[i]]-(sum[lca[i]]<<1));
if(sum2[i]>r)r=sum2[i];
}
int l=0;
while (l<=r){
int mid=l+r>>1;
if (check(mid)){
ans=mid;
r=mid-1;
}else{
l=mid+1;
}
}
printf("%d\n",ans);
return 0;
}