看到题意描述第一反应就是先二分那个修建的 mm 条赛道中长度最小的赛道的长度 k ,然后 O(n)O(n) 或 O(n\log n)O(nlogn) 判断。
那么怎么判断呢?
对于每个结点,把所有传上来的值 valval 放进一个 multisetmultiset ,其实这些值对答案有贡献就两种情况:
val≥k
val_a+val_b≥k
那么第一种情况可以不用放进 multiset,直接答案 +1 就好了。第二种情况就可以对于每一个最小的元素,在 multiset中找到第一个+val_b≥k 的数,将两个数同时删去,最后把剩下最大的值传到那个结点的父亲
因此这道题本质上是树的dfs和二分答案的结合。主要注意的是每个结点的子节点计算val_a+val_b≥k时,val_b尽量小,因而使用lower_bound
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <set>
#include <cmath>
using namespace std;
const int maxn = 50010;
struct node
{
int to;
int next;
int val;
} a[maxn << 1];
int n, m, head[maxn], tot = 0, ans, up;
multiset<int> setson[maxn];
inline int read()
{
int x = 0, y = 1; char c = getchar();
while (c>'9' || c<'0'){ if (c == '0')y = -1; c = getchar(); }
while (c >= '0'&&c <= '9'){ x = x * 10 + c - '0'; c = getchar(); }
return x*y;
}
void add(int x, int y, int w)
{
a[++tot].to = y;
a[tot].next = head[x];
a[tot].val = w;
head[x] = tot;
}
int getl(int x, int fa)
{
int sum1 = 0, sum2 = 0;
for (int i = head[x], y; i; i = a[i].next)
{
y = a[i].to;
if (y == fa)continue;
sum2 = max(sum2, (getl(y, x)+a[i].val));
if (sum1 < sum2)swap(sum1, sum2);
}
up = max(up, sum1 + sum2);
return sum1;
}
int dfs(int x, int fa, int mid)
{
setson[x].clear();
int val;
for (int i = head[x], y; i; i = a[i].next)
{
y = a[i].to;
if (y == fa)continue;
val = dfs(y, x, mid) + a[i].val;
if (val >= mid)ans++;
else
{
setson[x].insert(val);
}
}
multiset<int>::iterator it;
int MAX = 0;
while (!setson[x].empty())
{
if (setson[x].size() == 1){
return max(MAX, *setson[x].begin());
}
it = setson[x].lower_bound(mid - *setson[x].begin());
if (it == setson[x].begin() && setson[x].count(*it) == 1) it++;
if (it != setson[x].end())
{
ans++;
setson[x].erase(it);
setson[x].erase(setson[x].begin());
}
else if (it == setson[x].end()){
MAX = max(MAX, *setson[x].begin());
setson[x].erase(setson[x].find(*setson[x].begin()));
}
}
return MAX;
}
bool check(int mid)
{
ans = 0;
dfs(1, 0, mid);
if (ans >= m)return true;
return false;
}
int main()
{
n = read(), m = read();
int x, y, w;
for (int i = 1; i < n; ++i)
{
x = read();
y = read();
w = read();
add(x, y, w);
add(y, x, w);
}
getl(1, 0);
int L = 1, R = up, mid;
while (L < R)
{
mid = (L + R + 1) / 2;
if (check(mid))L = mid;
else
{
R = mid-1;
}
}
cout << L;
return 0;
}