** [八省联考2018]林克卡特树lct**
分析
很神仙的一道wqs二分。是真的不会切>-<
如果已经切完了,最优秀的方案就是每个联通块搞直径然后连起来一定是最优的。
换句话说,我们要在树上选择k+1条不同的链,使得这些链的长度之和最长。
转化了一步之后,我们就可以Dp了。
对于这种树上的链的Dp,一般的方法是分单链,过根链和子树链三种
设
f
[
0
/
1
/
2
]
[
i
]
[
k
]
f[0/1/2][i][k]
f[0/1/2][i][k]分别表示子树链,单链,过根链。
注意单链暂时不算链,合并的时候再算
f
[
0
]
[
u
]
[
k
]
=
f
[
0
]
[
u
]
[
j
]
+
f
[
0
]
[
s
o
n
]
[
k
−
j
]
f[0][u][k]=f[0][u][j]+f[0][son][k - j]
f[0][u][k]=f[0][u][j]+f[0][son][k−j]
f
[
1
]
[
u
]
[
k
]
=
max
{
f
[
1
]
[
u
]
[
j
]
+
f
[
0
]
[
s
o
n
]
[
k
−
j
]
,
f
[
0
]
[
u
]
[
j
]
+
f
[
1
]
[
s
o
n
]
[
k
−
j
]
+
w
}
f[1][u][k]=\max \{f[1][u][j]+f[0][son][k - j], f[0][u][j] + f[1][son][k-j] + w\}
f[1][u][k]=max{f[1][u][j]+f[0][son][k−j],f[0][u][j]+f[1][son][k−j]+w}
f
[
2
]
[
u
]
[
k
]
=
max
{
f
[
2
]
[
u
]
[
j
]
+
f
[
0
]
[
s
o
n
]
[
k
−
j
]
,
f
[
1
]
[
u
]
[
j
−
1
]
+
f
[
1
]
[
s
o
n
]
[
k
−
j
]
+
w
}
f[2][u][k]=\max \{f[2][u][j]+f[0][son][k - j], f[1][u][j-1]+f[1][son][k-j]+w\}
f[2][u][k]=max{f[2][u][j]+f[0][son][k−j],f[1][u][j−1]+f[1][son][k−j]+w}
然后一个神奇的结论是,
(
f
[
0
/
1
/
2
]
[
u
]
[
k
]
)
(f[0/1/2][u][k])
(f[0/1/2][u][k])是随
k
k
k上凸的,于是我们让选择一条链的时候付出一点“代价”(wqs的套路)
然后看它选了多少条链二分即可。
注意选链的时候要今年往少的选。
代码
#include<bits/stdc++.h>
const int N = 3e5 + 10;
int ri() {
char c = getchar(); int x = 0, f = 1; for(;c < '0' || c > '9'; c = getchar()) if(c == '-') f = -1;
for(;c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) - '0' + c; return x * f;
}
int to[N << 1], nx[N << 1], pr[N], w[N << 1], tp, n, k;
long long R, m;
void add(int u, int v, int _w) {to[++tp] = v; nx[tp] = pr[u]; pr[u] = tp; w[tp] = _w;}
void adds(int u, int v, int w) {add(u, v, w); add(v, u, w); R += abs(w);}
struct Data {
long long f; int k;
Data(long long _f = 0, int _k = 0) : f(_f), k(_k) {}
Data operator + (const Data &a) {return Data(f + a.f, k + a.k);}
}f[N][3];
Data add(Data a) {return Data(a.f - m, a.k + 1);}
Data max(Data a, Data b) {return (a.f == b.f ? a.k < b.k : a.f > b.f) ? a : b;}
void Dp(int u, int fa) {
f[u][2] = f[u][1] = f[u][0] = Data(0, 0);
f[u][2] = max(f[u][2], Data(-m, 1));
for(int i = pr[u]; i; i = nx[i])
if(to[i] != fa) {
Dp(to[i], u);
f[u][2] = max(f[u][2] + f[to[i]][0], add(f[u][1] + f[to[i]][1] + Data(w[i], 0)));
f[u][1] = max(f[u][1] + f[to[i]][0], f[u][0] + f[to[i]][1] + Data(w[i], 0));
f[u][0] = f[u][0] + f[to[i]][0];
}
f[u][0] = max(f[u][0], max(add(f[u][1]), f[u][2]));
}
int main() {
n = ri(); k = ri() + 1;
for(int i = 1, u, v;i < n; ++i)
u = ri(), v = ri(), adds(u, v, ri());
long long L = -R;
for(;L <= R;) {
m = L + R >> 1;
Dp(1, 0);
if(f[1][0].k <= k) R = m - 1;
else L = m + 1;
}
m = L;
Dp(1, 0);
printf("%lld\n", f[1][0].f + m * k);
return 0;
}