熟悉吗?那就对了。
看见题目的一瞬间就仿佛看到了树形dp+背包。
这样的问题显然可以拆分为子树内×子树外,递推计一半回溯计一半,
同时把答案分层,递推的时候累加。
然后对于每个点又可以把代价分摊给其子结点。
假如对于某个点,简单地考虑它两边的黑/白点组合,不考虑贡献重复的话
答案为子树外黑点个数×子树内黑点距离和+子树内黑点数×子树外黑点距离和
然后再加上白点的,同理。
通常这个套路有两种选择。
一个选择是把当前点固定为点对的一端,
现在分层,就是把状态拆到各时间当前考虑的点上。
那么考虑当前点是黑还是白的,然后可以计子树内外两部分的数。
F
(
u
,
k
)
=
∑
i
∉
X
&
v
i
s
[
i
,
f
a
[
i
]
=
f
a
[
u
]
]
=
y
d
i
s
i
+
∑
F
(
v
,
j
)
F
(
u
,
k
−
j
)
,
  
y
=
0
/
1
\mathcal{F_{(u,k)}=\sum\limits_{i\notin X\& vis[i,fa[i]=fa[u]]=y} dis_i+\sum F_{(v,j)}F_{(u,k-j)},\;y=0/1}
F(u,k)=i∈/X&vis[i,fa[i]=fa[u]]=y∑disi+∑F(v,j)F(u,k−j),y=0/1
单方向计数去重,最后累加;
但是这道题限制选点个数,得背包。于是↓
另一个选择就是,分别考虑子树内的答案怎么统计上来,以及怎么计算当前点贡献。
F
o
c
u
s
\mathcal{Focus}
Focus在父向边上。由于下面要做背包,应该注意连向子结点的边。
那么将一开始简单考虑时候的答案拆分出来,就是拆分到每条边上。按照套路我们可以把距离给拆分成多条边的权值和。
然后就变成
∑
v
a
l
e
d
g
e
×
(
c
n
t
b
l
a
c
k
,
i
n
s
i
d
e
×
c
n
t
b
l
a
c
k
,
o
u
t
s
i
d
e
+
c
n
t
w
h
i
t
e
,
i
n
s
i
d
e
×
c
n
t
w
h
i
t
e
,
o
u
t
s
i
d
e
)
\mathcal{\sum val_{edge}×(cnt_{black,inside}×cnt_{black,outside}+cnt_{white,inside}×cnt_{white,outside})}
∑valedge×(cntblack,inside×cntblack,outside+cntwhite,inside×cntwhite,outside)
(当然实际考虑的时候就不是这么直接了)
当然这个也不好统计,因为子树外的状况不好枚举?
不。黑色点有且只有
K
\mathcal{K}
K个,那么子树外黑点/白点的情况在枚举的时候也可以确定。
就可以了。
转移的时候照样枚举某个子树和其它子树分配多少然后跑背包。
需要注意转移的时候并不需要考虑当前点选了什么:
因为我们实际上枚举的应该是“当前边”。
分析一下复杂度吧。
这种写法很常见的,不过经常被当做是
Θ
(
n
3
)
\mathcal{\Theta(n^3)}
Θ(n3)
实际上它是
Θ
(
n
2
)
\mathcal{\Theta(n^2)}
Θ(n2)的(至少这道题是)
因为复杂度是
n
∑
s
i
z
u
(
s
i
z
u
−
s
i
z
v
)
\mathcal{n\sum siz_u(siz_u-siz_v)}
n∑sizu(sizu−sizv)(注意写代码的时候要搞好枚举的界限不然可能会被卡)
实质上就是枚举了所有的点对
而且我们是在每个点对的
L
C
A
\mathcal{LCA}
LCA处枚举它们的(这个很容易想到)
由此可得复杂度就是枚举点对的复杂度
Θ
(
n
2
)
\mathcal{\Theta(n^2)}
Θ(n2)
……
然而转移的时候方式要正确,否则就会产生大量的重复计算。
那么这个复杂度在被给了一条链的时候会炸成
Θ
(
n
3
)
\mathcal{\Theta(n^3)}
Θ(n3)。
具体讨论
记得开 l o n g    l o n g \mathcal{long\;long} longlong。
不正确的转移(n^3)
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstdlib>
#include<cmath>
#include<cstring>
using namespace std;
#define add_edge(a, b, c) nxt[++tot] = head[a], head[a] = tot, to[tot] = b, val[tot] = c
int head[2005];
int nxt[4005];
int val[4005];
int to[4005];
long long f[2005][2005];
int siz[2005];
int n, K, tot;
void dfs(const int &x, const int &fa) {
siz[x] = 1;
for (int i = head[x]; i; i = nxt[i]) {
if (to[i] == fa) continue;
dfs(to[i], x);
siz[x] += siz[to[i]];
}
for (int i = 2; i <= min(K, siz[x]); ++i) {
f[x][i] = -1; //*
}
for (int i = head[x]; i; i = nxt[i]) {
if (to[i] == fa) continue;
for (int j = min(K, siz[x]); j >= 0; --j) {
int up = min(siz[to[i]], j);
for (int k = 0; k <= up; ++k) {
if (f[x][j-k] == -1) continue;
f[x][j] = max(f[x][j], f[x][j-k] + f[to[i]][k] + 1ll * val[i] * (1ll * k * (K - k) + 1ll * (siz[to[i]] - k) * (n - K - siz[to[i]] + k)));
}
}
}
}
int main() {
scanf("%d%d", &n, &K);
for (int u, v, w, i = 1; i < n; ++i) {
scanf("%d%d%d", &u, &v, &w);
add_edge(u, v, w);
add_edge(v, u, w);
}
dfs(1, 0);
printf("%lld", f[1][K]);
return 0;
}
n^2
这种写法不需要初始化f,因为siz的更新在求dp值之后
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstdlib>
#include<cmath>
#include<cstring>
using namespace std;
#define add_edge(a, b, c) nxt[++tot] = head[a], head[a] = tot, to[tot] = b, val[tot] = c
#define getchar() (frS==frT&&(frT=(frS=frBB)+fread(frBB,1,1<<12,stdin),frS==frT)?EOF:*frS++)
char frBB[1<<12], *frS=frBB, *frT=frBB;
int read(int &x) {
x = 0;
char ch = getchar();
while (!isdigit(ch)) ch = getchar();
while (isdigit(ch)) x = x * 10 + (ch ^ 48), ch = getchar();
return x;
}
int head[2005];
int nxt[4005];
int val[4005];
int to[4005];
long long f[2005][2005];
int siz[2005];
int n, K, tot;
void dfs(const int &x, const int &fa) {
siz[x] = 1;
for (register int i = head[x]; i; i = nxt[i]) {
if (to[i] == fa) continue;
dfs(to[i], x);
for (int j = min(K, siz[x]); j >= 0; --j) {
int up = min(siz[to[i]], K - j);
for (int k = up; k >= 0; --k) {
f[x][j+k] = max(f[x][j+k], f[x][j] + f[to[i]][k] + 1ll * val[i] * (1ll * k * (K - k) + 1ll * (siz[to[i]] - k) * (n - K - siz[to[i]] + k)));
}
}
siz[x] += siz[to[i]];
}
}
int main() {
read(n); read(K);
int u, v, w;
for (register int i = 1; i < n; ++i) {
read(u); read(v); read(w);
add_edge(u, v, w);
add_edge(v, u, w);
}
dfs(1, 0);
printf("%lld", f[1][K]);
return 0;
}