题意
一个树上有 N ( N ≤ 500 ) N(N\leq500) N(N≤500)个节点,每个节点与其父节点有正权值,表示距离。你需要回答 Q ( Q ≤ 1000 ) Q(Q\leq1000) Q(Q≤1000)个询问,每个询问给出个 x x x,求从根节点出发走不超过 x x x距离最多能走到多少个节点。
解题思路
先想暴力解法,从根节点出发,对于任意一个节点,你可以回到根节点或者选择任何一个子节点往下走,把所有可能路线都走一遍直到
x
x
x耗尽。
我们可以发现我们的决策十分有限而且有一些重复状态,所以很容易想到用
d
p
dp
dp去求解。由于
x
x
x很大,所以把经过的路程计入状态是不行的,但是我们可以反着来,记录经过
k
k
k个点最少需要多少路程。于是我们有状态
d
p
[
u
]
[
k
]
[
0
/
1
]
dp[u][k][0/1]
dp[u][k][0/1]代表节点
u
u
u以及其子树经过
k
k
k个节点,
1
1
1表示回到
u
u
u,
0
0
0表示不回所需要的最短路程。接下来考虑如何转移。
对于节点
u
u
u,我们可以往
u
u
u的一个子节点
v
v
v走。如果返回
u
u
u那么
v
v
v一定要先返回,而且我们能走这条路的前提是其他子节点也都返回过
u
u
u。经过点数的更新其实就是个树上分组背包问题,
j
j
j代表子节点选了多少个,于是我们有
d
p
[
u
]
[
i
]
[
1
]
=
m
i
n
(
d
p
[
u
]
[
i
]
[
1
]
,
d
p
[
u
]
[
i
−
j
]
[
1
]
+
d
p
[
v
]
[
j
]
[
1
]
+
2
w
u
,
v
)
dp[u][i][1]=min(dp[u][i][1], dp[u][i-j][1]+dp[v][j][1]+2w_{u,v})
dp[u][i][1]=min(dp[u][i][1],dp[u][i−j][1]+dp[v][j][1]+2wu,v)
如果不返回
u
u
u,那么遍历就会停在
v
v
v这个节点,因为你不回到
u
u
u就没办法继续往下走。还有一种情况就是你自己回到
u
u
u,这就允许其他子节点不回到
u
u
u,那么就是:
d
p
[
u
]
[
i
−
j
]
[
0
]
+
d
p
[
v
]
[
j
]
[
1
]
+
2
w
u
,
v
,
d
p
[
u
]
[
i
−
j
]
[
1
]
+
d
p
[
v
]
[
j
]
[
0
]
+
w
u
,
v
dp[u][i-j][0]+dp[v][j][1]+2w_{u,v},\\ dp[u][i-j][1]+dp[v][j][0]+w_{u,v}
dp[u][i−j][0]+dp[v][j][1]+2wu,v,dp[u][i−j][1]+dp[v][j][0]+wu,v
当然你也可以只走这个子节点,这就等同于下面的情况。只要你把
d
p
[
u
]
[
1
]
[
1
/
0
]
dp[u][1][1/0]
dp[u][1][1/0]设为0,这就是
d
p
dp
dp的边界情况,
k
k
k是不会为0的,因为无论如何自己你是能达到的。这就是个非法状态。
最后求答案的时候只要二分一下满足条件的 k k k就好了。
时间复杂度
O ( N 3 + Q log N ) O(N^3+Q\log{N}) O(N3+QlogN)
核心代码
void dfs(int u) {
sz[u] = 1;
dp[u][1][1] = dp[u][1][0] = 0;
for (auto v : G[u]) {
dfs(v.to);
sz[u] += sz[v.to];
}
for (auto v : G[u]) {
for (int i = sz[u]; i >= 2; i--) {
for (int j = 1; j <= min(i, sz[v.to]); j++) {
dp[u][i][0] =
min(dp[u][i][0], dp[u][i - j][1] + dp[v.to][j][0] + v.dis);
dp[u][i][0] = min(dp[u][i][0],
dp[u][i - j][0] + dp[v.to][j][1] + v.dis * 2);
dp[u][i][1] = min(dp[u][i][1],
dp[u][i - j][1] + dp[v.to][j][1] + v.dis * 2);
}
}
}
}
void solve() {
for (int i = 0; i < n; i++) {
G[i].clear();
}
memset(dp, 0x3f, sizeof(dp));
for (int i = 0; i < n - 1; i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
G[v].push_back({u, w});
}
dfs(0);
vector<int> res;
res.push_back(0);
for (int i = 2; i <= n; i++) {
res.push_back(dp[0][i][0]);
}
int q;
scanf("%d", &q);
printf("Case %d:\n", CASE++);
while (q--) {
int x;
scanf("%d", &x);
auto p = upper_bound(res.begin(), res.end(), x);
if (p != res.begin()) --p;
int ans = p - res.begin() + 1;
printf("%d\n", ans);
}
}