让我来提供一个与众不同的贪心做法,不需要任何数据结构,O(nlogn)
删掉一条边的贡献就是max(不包含这条边的路径的最大值,包含这条边的路径的最大值-这条边的长度)
所以我们只需要先将路径按长度从大到小排序,然后依次求出前i条路径的交集中的最大边,统计一下即可
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>
using namespace std;
// Reads the next decimal integer from stdin into x.
//
// Every character before the first digit is skipped; if any skipped character
// is '-', the parsed value is negated. Consecutive digits are then accumulated.
// If the input ends before a digit is found, x is left at 0 and the function
// returns (the character is kept as int so EOF can be told apart from data,
// instead of looping forever on end of input).
inline void read(int &x) {
    x = 0;
    bool negative = false;
    int c = getchar();
    // Skip non-digits, remembering whether a minus sign was seen.
    while (c != EOF && (c < '0' || '9' < c)) {
        negative |= (c == '-');
        c = getchar();
    }
    // Accumulate the digit run; EOF is not a digit, so it also ends this loop.
    while ('0' <= c && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    if (negative) x = -x;
}
// Array capacity (with a little slack) and the number of binary-lifting levels.
const int N = 3e5 + 10;
const int M = 19;

// n, m: the two counts read first in main (tree size, number of paths).
// t: highest binary-lifting level tried by getlca/getmax, set in main.
int n, m, t;

// Forward-star adjacency lists. Edge slots are numbered from 2 (tot starts
// at 1), so the two directions of the k-th input edge occupy slots 2k and
// 2k+1 and share the weight Leng[k].
int tot = 1;
int Head[N];       // Head[x]: most recently added edge slot leaving x, 0 if none
int ver[N << 1];   // ver[e]: node that edge slot e points to
int Next[N << 1];  // Next[e]: next edge slot in the same adjacency list
int Leng[N];       // Leng[k]: weight of the k-th input edge

// Per-node tables filled by dfs.
int fa[N][M];      // fa[x][i]: ancestor 2^i steps above x (0 if there is none)
int s[N][M];       // s[x][i]: largest edge weight on those 2^i steps
int dep[N];        // depth in edges + 1 (main sets dep of node 1 to 1)
int dis[N];        // sum of edge weights from node 1 down to the node
// One requested path in the tree, together with values main precomputes for it.
struct query {
    int x;    // first endpoint
    int y;    // second endpoint
    int dis;  // weighted length of the x-y path
    int lca;  // lowest common ancestor of x and y
};

// Path storage; main fills and sorts entries starting at index 1.
query q[N];
// Inserts the directed edge x -> y at the front of x's adjacency list.
void add(int x, int y) {
    const int slot = ++tot;  // claim the next unused edge slot
    ver[slot] = y;
    Next[slot] = Head[x];    // the old first edge now follows the new one
    Head[x] = slot;
}
void dfs(int x) {
for (int i = 1; (1 << i) < dep[x]; i++) {
int y = fa[x][i - 1];
fa[x][i] = fa[y][i - 1];
s[x][i] = max(s[x][i - 1], s[y][i - 1]);
}
for (int i = Head[x]; i; i = Next[i]) {
int y = ver[i];
if (fa[x][0] == y) continue;
fa[y][0] = x;
s[y][0] = Leng[i >> 1];
dep[y] = dep[x] + 1;
dis[y] = dis[x] + Leng[i >> 1];
dfs(y);
}
}
// Returns the lowest common ancestor of x and y using the fa/dep tables.
int getlca(int x, int y) {
    // Make x the deeper endpoint.
    if (dep[x] < dep[y]) std::swap(x, y);

    // Lift x until it is level with y. Missing ancestors are node 0 with
    // depth 0, so they never satisfy the depth test.
    for (int k = t; k >= 0; --k) {
        const int up = fa[x][k];
        if (dep[up] >= dep[y]) x = up;
    }
    if (x == y) return x;

    // Lift both in lockstep while their ancestors still differ.
    for (int k = t; k >= 0; --k) {
        const int ux = fa[x][k];
        const int uy = fa[y][k];
        if (ux == uy) continue;
        x = ux;
        y = uy;
    }
    return fa[x][0];
}
// Ordering for std::sort: paths with larger length come first.
bool cmp(query p1, query p2) {
    return p2.dis < p1.dis;
}
// Returns the largest edge weight on the tree path between x and y
// (0 when x == y), climbing with the same jumps as getlca.
int getmax(int x, int y) {
    if (dep[x] < dep[y]) std::swap(x, y);

    int best = 0;

    // Raise the deeper endpoint to y's depth, collecting edge maxima.
    for (int k = t; k >= 0; --k) {
        if (dep[fa[x][k]] < dep[y]) continue;
        best = std::max(best, s[x][k]);
        x = fa[x][k];
    }
    if (x == y) return best;

    // Raise both endpoints together until they sit just below the LCA.
    for (int k = t; k >= 0; --k) {
        if (fa[x][k] == fa[y][k]) continue;
        best = std::max(best, std::max(s[x][k], s[y][k]));
        x = fa[x][k];
        y = fa[y][k];
    }

    // Account for the final two edges that lead into the LCA.
    return std::max(best, std::max(s[x][0], s[y][0]));
}
// Ordering for std::sort: deeper nodes come first.
bool cmp1(int p1, int p2) {
    return dep[p2] < dep[p1];
}
int main() {
cin >> n >> m;
t = log(n) / log(2);
for (int i = 1; i < n; i++) {
int x, y, z;
read(x), read(y), read(z);
add(x, y);
add(y, x);
Leng[tot >> 1] = z;
}
dep[1] = 1; dfs(1);
int mx = 0;
for (int i = 0, j = 1; j <= m; j++) {
int x, y; read(x), read(y);
if (x == y) continue;
mx = ++i;
q[i].x = x, q[i].y = y;
q[i].lca = getlca(q[i].x, q[i].y);
q[i].dis = dis[q[i].x] - 2 * dis[q[i].lca] + dis[q[i].y];
}
m = mx;
if (!m) { puts("0"); return 0; }
sort(q + 1, q + m + 1, cmp);
mx = getmax(q[1].x, q[1].y);
int ans = max(q[2].dis, q[1].dis - mx), x = q[1].x, y = q[1].y, lc = q[1].lca;
for (int i = 2; i <= m; i++) {
int tx = x, ty = y, p[5];
p[1] = getlca(tx, q[i].x);
p[2] = getlca(tx, q[i].y);
p[3] = getlca(ty, q[i].x);
p[4] = getlca(ty, q[i].y);
sort(p + 1, p + 4 + 1, cmp1);
x = p[1], y = p[2];
if (x == y || dep[x] <= max(dep[q[i].lca], dep[lc])) break;
mx = getmax(x, y);
lc = getlca(x, y);
ans = min(ans, max(q[1].dis - mx, q[i + 1].dis));
}
cout << ans << endl;
return 0;
}