题面
解法
神仙题
- 简化一下题面就是选出不相交的 k + 1 k+1 k+1条链,使得边权之和最大。
- 先写写部分分好了。
- k = 0 / 1 k=0/1 k=0/1比较简单,求一求直径就好了,具体细节不再赘述(我记得我当初在考场上的时候竟然用两次dfs求直径,然后因为有负数就只有5分……)
- k = 2 k=2 k=2可能是大分类讨论,表示并不会……考虑 k ≤ 100 k\leq100 k≤100怎么处理。
- 可以考虑树形dp。 f [ x ] [ i ] [ 0 ] f[x][i][0] f[x][i][0]表示根为 x x x的子树中选择 i i i条链,且 x x x这个点不在某一条链上; f [ x ] [ i ] [ 1 ] f[x][i][1] f[x][i][1]表示根为 x x x的子树中选择 i i i条链,且 x x x为链的某一个端点; f [ x ] [ i ] [ 2 ] f[x][i][2] f[x][i][2]表示根为 x x x的子树中选择 i i i条链,且 x x x这个点并不作为端点而是在路径上。
- 转移比较简单,只是在处理 f [ x ] [ i ] [ 2 ] f[x][i][2] f[x][i][2]的时候可能要稍微注意一下,就是两条路径在合并的时候最终是变成1条路径,而不是两条,所以并不能直接相减,需要+1。
- 时间复杂度: O ( n k ) O(nk) O(nk)。
- 考虑怎么处理 k ≤ n k\leq n k≤n的情况。可以发现,将 ( i , F ( i ) ) (i,F(i)) (i,F(i))看作一个点的话,那么这 k k k个点构成了一个凸壳(并不知道为什么)。
- 然后可以二分一个斜率 m i d mid mid,表示选择一条路径需要额外付出的代价为 m i d mid mid,然后再次进行树形dp,不过这一次并不需要记录相关的链的个数。dp数组 f [ x ] [ 0 / 1 / 2 ] f[x][0/1/2] f[x][0/1/2]记录一个最大值 v v v和相应的路径条数 s s s,转移类似于 60 60 60分的dp。
- 最后在二分的时候看选出的路径条数与 k k k的关系即可。
- 时间复杂度: O ( n log v ) O(n\log v) O(nlogv)
【注意事项】
- 写dp的时候一定要注意边界条件!!!
代码
#include <bits/stdc++.h>
#define ll long long
using namespace std;
template <typename T> void chkmax(T &x, T y) {x = x > y ? x : y;}
template <typename T> void chkmin(T &x, T y) {x = x < y ? x : y;}
template <typename T> void read(T &x) {
x = 0; int f = 1; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
const int N = 300010; const ll inf = 1ll << 60;
int n, K, cnt, siz[N], head[N];
struct Edge {int next, num, v;} e[N * 3];
struct Node {ll v, s;} f[N][3];
bool operator < (Node a, Node b) {return a.v == b.v ? a.s < b.s : a.v < b.v;}
Node operator + (Node a, Node b) {return {a.v + b.v, a.s + b.s};}
void add(int x, int y, int v) {
e[++cnt] = (Edge) {head[x], y, v};
head[x] = cnt;
}
void dfs(int x, int fa, ll mid) {
for (int p = head[x]; p; p = e[p].next) {
int y = e[p].num; ll v = e[p].v;
if (y == fa) continue; dfs(y, x, mid);
f[x][2] = max(f[x][2] + f[y][0], f[x][1] + f[y][1] + (Node) {v - mid, 1});
f[x][1] = max(f[x][0] + f[y][1] + (Node) {v, 0}, f[x][1] + f[y][0]);
f[x][0] = f[x][0] + f[y][0];
}
f[x][0] = max(f[x][0], max(f[x][1] + (Node) {-mid, 1}, f[x][2]));
}
void Init(ll mid) {
for (int i = 1; i <= n; i++)
f[i][0] = f[i][1] = {0, 0}, f[i][2] = {-mid, 1};
}
int main() {
read(n), read(K);
for (int i = 1; i < n; i++) {
int x, y, v; read(x), read(y), read(v);
add(x, y, v), add(y, x, v);
}
K++; ll l = -1e12, r = 1e12, ans = 0;
while (l <= r) {
ll mid = (l + r) >> 1; Init(mid), dfs(1, 0, mid);
Node tmp = f[1][0];
if (tmp.s >= K) ans = mid, l = mid + 1; else r = mid - 1;
}
Init(ans), dfs(1, 0, ans); Node tmp = f[1][0];
cout << tmp.v + 1ll * ans * K << "\n";
return 0;
}