题目大意
给定一棵树,
n
−
1
n-1
n−1条边。问如何在边上填数(范围从0到n-1,且每个数仅出现一次)使得
S
S
S 最小
S
=
∑
1
≤
u
,
v
≤
n
m
e
x
(
u
,
,
v
)
S = \sum_{1\leq u,v \leq n}mex(u,,v)
S=1≤u,v≤n∑mex(u,,v)
其中
m
e
x
(
u
,
v
)
mex(u, v)
mex(u,v)表示从u到v的路径上最小没出现的自然数。
数据范围:
2
≤
n
≤
3000
2 \leq n \leq 3000
2≤n≤3000
解题思路
假设我们已经知道填0的边所在的位置(边
<
u
,
v
>
<u,v>
<u,v>),那么填1的边与u或者v相连一定最优。如果找到一条边<u, w>填1,那么填2的边一定与w或者v相连最优,以此类推,可以发现我们似乎就在维护一个链,而这条链一但确定,其它边不管怎么填都不会影响答案。
再考虑如果我们在上面的链上加入一条边时答案该如何变化。如果加入一条边后,答案会增加u,v子树中所有点对的数量,也就是子树大小的乘积。
这就把问题转化成了一个类似区间dp的问题。如果我们想求链从u到v时的答案
f
[
u
]
[
v
]
f[u][v]
f[u][v],那么
f
[
u
]
[
v
]
=
m
a
x
(
f
[
f
a
[
v
]
[
u
]
]
[
v
]
,
f
[
u
]
[
f
a
[
u
]
[
v
]
]
)
+
s
u
m
[
u
]
[
v
]
∗
s
u
m
[
v
]
[
u
]
f[u][v] = max(f[fa[v][u]][v], f[u][fa[u][v]]) +sum[u][v]*sum[v][u]
f[u][v]=max(f[fa[v][u]][v],f[u][fa[u][v]])+sum[u][v]∗sum[v][u]
其中
f
a
[
u
]
[
v
]
fa[u][v]
fa[u][v]表示在以u为根的树中,v的父节点,
s
u
m
[
u
]
[
v
]
sum[u][v]
sum[u][v]表示在以u为根结点的子树中,v结点子树的大小。这两个通过
O
(
n
2
)
O(n^2)
O(n2)的预处理就能得到。
该dp方程类似于区间dp的方程,
f
[
i
]
[
j
]
=
m
a
x
(
f
[
i
−
1
]
[
j
]
,
f
[
i
]
[
j
−
1
]
)
+
w
[
i
]
[
j
]
f[i][j]=max(f[i-1][j], f[i][j-1])+w[i][j]
f[i][j]=max(f[i−1][j],f[i][j−1])+w[i][j]
此题可以看作区间dp在树上的一种变形。
代码实现
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
const int MAXN = 6005;
int head[MAXN], nxt[MAXN], to[MAXN], sze, n;
inline void AddEdge(int u, int v) {
nxt[++sze] = head[u]; to[head[u] = sze] = v;
}
int sum[MAXN][MAXN], fa[MAXN][MAXN];
LL f[MAXN / 2][MAXN / 2]; // 不除2会炸内存
void dfs(int root, int u, int faa) {
fa[root][u] = faa;
sum[root][u] = 1;
for (int e = head[u]; e; e = nxt[e]) {
if (to[e] == faa) continue;
dfs(root, to[e], u);
sum[root][u] += sum[root][to[e]];
}
}
LL dp(int u, int v) {
if (u == v) return f[u][v] = 0;
if (u > v) swap(u, v);
if (f[u][v] != -1) return f[u][v];
f[u][v] = max(dp(u, fa[u][v]), dp(v, fa[v][u])) + (LL)sum[u][v] * sum[v][u];
return f[u][v];
}
int main() {
scanf("%d", &n);
for (int i = 1; i < n; i++) {
int u, v; scanf("%d%d", &u, &v);
AddEdge(u, v); AddEdge(v, u);
}
for (int i = 1; i <= n; i++)
dfs(i, i, i);
memset(f, -1, sizeof(f));
LL ans = 0;
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
ans = max(ans, dp(i, j));
printf("%lld", ans);
return 0;
}