题意
给定一棵 n n n个节点的树,给定 m m m条路径,现在将树上一条边的边权变成 0 0 0,使得这 m m m条路径的最大值最小。
数据范围
1 ≤ n , m ≤ 300000 1 \leq n, m \leq 300000 1≤n,m≤300000
思路
考虑二分答案。每次判断的时候,考察长度大于
m
i
d
mid
mid的路径,我们需要删掉所有这样的路径都经过的边。如果删掉的边的权重大于等于最长路径与
m
i
d
mid
mid之差,那么
m
i
d
mid
mid就是可行的。
那么如何维护所有长度大于
m
i
d
mid
mid的路径共同经过的边呢?可以通过树上差分来实现,如果当前路径长度大于
m
i
d
mid
mid,那么将这条路径上的所有边
+
1
+1
+1。最后找被加次数等于长度大于
m
i
d
mid
mid的路径的数量即可。
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
typedef pair<int,int> pii;
const int N = 300010, M = 2 * N;
int n, m;
int h[N], e[M], ne[M], w[M], idx;
int fa[N][20], depth[N], dist[N];
int blca[N];
queue<int> que;
int sum[N];
pii trans[N];
void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++;
}
void bfs()
{
memset(depth, 0x3f, sizeof(depth));
depth[0] = 0, depth[1] = 1;
que.push(1);
while(que.size()) {
int t = que.front();
que.pop();
for(int i = h[t]; ~i; i = ne[i]) {
int j = e[i];
if(depth[j] > depth[t] + 1) {
depth[j] = depth[t] + 1;
dist[j] = dist[t] + w[i];
que.push(j);
fa[j][0] = t;
for(int k = 1; k < 20; k ++) {
fa[j][k] = fa[fa[j][k-1]][k-1];
}
}
}
}
}
int lca(int a, int b)
{
if(depth[a] < depth[b]) swap(a, b);
for(int k = 19; k >= 0; k --) {
if(depth[fa[a][k]] >= depth[b]) {
a = fa[a][k];
}
}
if(a == b) return a;
for(int k = 19; k >= 0; k --) {
if(fa[a][k] != fa[b][k]) {
a = fa[a][k];
b = fa[b][k];
}
}
return fa[a][0];
}
void dfs_sum(int u, int fa)
{
for(int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if(j == fa) continue;
dfs_sum(j, u);
sum[u] += sum[j];
}
}
bool check(int mid)
{
memset(sum, 0, sizeof(sum));
int s = 0, maxd = 0;
for(int i = 0; i < m; i ++) {
int a = trans[i].first, b = trans[i].second;
int p = blca[i];
int d = dist[a] + dist[b] - 2 * dist[p];
if(d > mid){
sum[a] += 1, sum[b] += 1, sum[p] -= 2;
maxd = max(maxd, d - mid);
s ++;
}
}
if(s == 0) return true;
dfs_sum(1, -1);
for(int i = 1; i <= n; i ++) {
if(sum[i] == s && dist[i] - dist[fa[i][0]] >= maxd) {
return true;
}
}
return false;
}
int main()
{
scanf("%d%d", &n, &m);
memset(h, -1, sizeof(h));
for(int i = 1; i < n; i ++) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
bfs();
for(int i = 0; i < m; i ++) {
int a, b;
scanf("%d%d", &a, &b);
blca[i] = lca(a, b);
trans[i] = {a, b};
}
int l = 0, r = 1e9;
while(l < r) {
int mid = l + r >> 1;
if(check(mid)) r = mid;
else l = mid + 1;
}
printf("%d\n", r);
return 0;
}