题目链接:点我啊╭(╯^╰)╮
题目大意:
一颗带正负边权的树
删去
k
k
k 条边后,再任意加上
k
k
k 条边权为
0
0
0 的边
问任意两点简单路径的最大值
解题思路:
题目可以转化为:
求不相交的
k
+
1
k+1
k+1 条链的最大边权和
那么树形
d
p
dp
dp 即可,时间复杂度为
O
(
n
k
2
)
O(nk^2)
O(nk2)
设
d
p
[
i
]
[
j
]
[
k
]
dp[i][j][k]
dp[i][j][k] 为以
i
i
i 为根节点的子树,不相交的
j
j
j 条链,且
i
i
i 的链度数为
k
k
k 的最大值
d
p
[
i
]
[
j
]
[
2
]
dp[i][j][2]
dp[i][j][2] 表示
i
i
i 为链的路径
d
p
[
i
]
[
j
]
[
1
]
dp[i][j][1]
dp[i][j][1] 表示
i
i
i 为链的端点
d
p
[
i
]
[
j
]
[
0
]
dp[i][j][0]
dp[i][j][0] 表示
i
i
i 与链不相交
然后用
w
q
s
wqs
wqs 二分优化掉
d
p
dp
dp 里的
k
k
k
时间复杂度就降为了
O
(
n
)
O(n)
O(n)
注意要一开始就对
d
p
[
i
]
[
1
]
dp[i][1]
dp[i][1] 与
d
p
[
i
]
[
2
]
dp[i][2]
dp[i][2] 赋值,表示增加一条链的代价
两条链合并为一条链时,原来要
−
-
− 斜率
k
k
k
因为一开始赋了值,因此这里变成
+
+
+ 斜率
k
k
k
。。但是我看有些博客是不对
d
p
[
i
]
[
1
]
dp[i][1]
dp[i][1] 赋值,合并
−
-
− 斜率
k
k
k,感觉不是那么容易理解。。
// https://loj.ac/problem/2478
// https://oj.lpoj.cn/problemdetail?problemID=5793
#include<bits/stdc++.h>
#define rint register int
#define deb(x) cerr<<#x<<" = "<<(x)<<'\n';
using namespace std;
typedef long long ll;
typedef pair <int,int> pii;
const int maxn = 3e5 + 5;
int n, k, cnt;
vector <pii> g[maxn];
ll mid, tot;
struct node{
ll v, num;
bool operator < (const node &A) const {
return v == A.v ? num > A.num : v < A.v;
}
node operator + (const node &A) const {
return {v + A.v, num + A.num};
}
node operator + (int A) const {
return {v + A, num};
}
} dp[maxn][3];
void dfs(int u, int fa){
for(auto tv : g[u]){
int v = tv.first, w = tv.second;
if(v == fa) continue;
dfs(v, u);
dp[u][2] = max(dp[u][2] + dp[v][0], dp[u][1] + dp[v][1] + w + (node){mid, -1});
dp[u][1] = max(dp[u][1] + dp[v][0], dp[u][0] + dp[v][1] + w);
dp[u][0] = dp[u][0] + dp[v][0];
}
dp[u][0] = max(dp[u][0], max(dp[u][1], dp[u][2]));
}
bool ck(ll x){
for(int i=1; i<=n; i++)
dp[i][1] = dp[i][2] = (node){-x, 1}, dp[i][0] = (node){0, 0};
dfs(1, 0);
return dp[1][0].num <= k;
}
signed main() {
scanf("%d%d", &n, &k); k++;
for(int i=1; i<n; i++){
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
g[u].push_back({v, w});
g[v].push_back({u, w});
tot += abs(w);
}
ll l = -tot, r = tot;
while(l <= r){
mid = l + r >> 1;
if(ck(mid)) r = mid - 1;
else l = mid + 1;
}
ck(mid = l);
printf("%lld", dp[1][0].v + l * k);
}