【题目链接】
【思路要点】
- 问题等价于在树上选出\(K+1\)条点不相交的路径使得它们权值和最大。
- 首先考虑一个比较显然的DP。
- 记\(dp_{i,j,k}\)表示以\(i\)为根的子树中选取了\(j\)条路径,且点\(i\)的度数为\(k(k=0,1,2)\)的最大权值和。
- 但这个DP状态数太多了,显然无法通过。
- 仔细分析一下题目,我们发现,对于同一棵树,令\(f(x)\)为当\(K\)取\(x\)时问题的最优解,那么\(f\)是一个上凸函数。这是因为每加入一条路径后,答案变优的幅度一定是不增的(可以打表看看,或者感性理解一下)。
- 知道这个性质以后,我们可以二分一个斜率\(slope\),来切这个函数图像。即求出每选取一条路径需要额外失去\(slope\)点权值时的最大权值和,以及此时选取路径的条数,此时选取路径的条数就是切点的横坐标。
- 而每选取一条路径需要额外失去\(slope\)点权值时的最大权值和是容易的,我们可以在\(O(N)\)的时间内用DP求解。
- 时间复杂度\(O(NLogW)\),需要注意函数上仍然有可能存在三点共线,因此有时斜率和函数产生的不一定是切点,也有可能是公共线段。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 300005; const long long INF = 4e18; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } struct info {int cnt; long long val; }; struct edge {int dest, len; }; info operator + (info a, info b) {return (info) {a.cnt + b.cnt, a.val + b.val}; } bool operator > (info a, info b) { if (a.val == b.val) return a.cnt < b.cnt; else return a.val > b.val; } info max(info a, info b) { if (a > b) return a; else return b; } int n, k; long long slope; info dp[MAXN][3]; vector <edge> a[MAXN]; info addpath(info x) { x.cnt += 1; x.val -= slope; return x; } info modify(info x, int cnt, long long val) { x.cnt += cnt; x.val += val; return x; } void work(int pos, int fa) { dp[pos][0] = (info) {0, 0}; dp[pos][1] = dp[pos][2] = (info) {0, -INF}; for (unsigned i = 0; i < a[pos].size(); i++) { int tmp = a[pos][i].dest, len = a[pos][i].len; if (tmp == fa) continue; work(tmp, pos); chkmax(dp[pos][2], dp[pos][2] + dp[tmp][2]); chkmax(dp[pos][2], modify(dp[pos][1] + dp[tmp][1], -1, len + slope)); chkmax(dp[pos][1], dp[pos][1] + dp[tmp][2]); chkmax(dp[pos][1], modify(dp[pos][0] + dp[tmp][1], 0, len)); chkmax(dp[pos][0], dp[pos][0] + dp[tmp][2]); } chkmax(dp[pos][1], addpath(dp[pos][0])); chkmax(dp[pos][2], max(dp[pos][0], dp[pos][1])); } info check() { work(1, 0); return dp[1][2]; } int main() { read(n), read(k); for (int i = 1; i <= n - 1; i++) { int x, y, z; read(x), read(y), read(z); a[x].push_back((edge) {y, z}); a[y].push_back((edge) {x, z}); } long long l = -1e12, r = 1e12, ans; while (l <= r) { slope = (l + r) / 2; info tmp = check(); if (tmp.cnt > k + 1) l = slope + 1; else { ans = tmp.val + slope * (k + 1); r = slope - 1; } } printf("%lld\n", ans); return 0; }