AcWing 264. 权值
题意
给一棵 n n n 个节点的树,边带权。求一条简单路径,使得这条路径上边的权值之和为 k k k ,且包含边的数量最少。
解法
树上路径询问?立即推:点分治!
- 枚举每个点作为
lca
,维护所有子树到这个点的距离的最小深度,遍历每一棵子树,先更新答案,再更新最小深度,遍历完所有子树之后dfs
清空最小深度为 i n f inf inf; - 注意距离有可能很大,但是需要存储小于等于 k k k 的最小深度即可;
-
m
i
n
n
[
0
]
=
0
minn[0]=0
minn[0]=0 表示路径的其中一个端点就是
lca
; - 第一发T了,把
memset
换成循环就AC了,当然用前向星应该也可以AC。
代码
#pragma region
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <vector>
using namespace std;
typedef long long ll;
#define rep(i, a, n) for (int i = a; i <= n; ++i)
#define per(i, a, n) for (int i = n; i >= a; --i)
#pragma endregion
const int maxn = 2e5 + 5;
const int inf = 0x3f3f3f3f;
int n, k;
vector<pair<int, int>> g[maxn];
int sz[maxn], rt, dep[maxn], minn[int(1e6 + 5)];
ll d[maxn];
bool vis[maxn];
void dfs_rt(int u, int f, int tot) {
sz[u] = 1;
int maxx = 0;
for (auto e : g[u]) {
int v = e.first;
if (vis[v] || v == f) continue;
dfs_rt(v, u, tot);
sz[u] += sz[v];
maxx = max(maxx, sz[v]);
}
maxx = max(maxx, tot - sz[u]);
if (maxx * 2 <= tot) rt = u;
}
int cnt;
void dfs_ans(int u, int f, int &ans) {
++cnt;
if (d[u] <= k) ans = min(ans, dep[u] + minn[k - d[u]]);
for (auto e : g[u]) {
int v = e.first, w = e.second;
if (vis[v] || v == f) continue;
d[v] = d[u] + w, dep[v] = dep[u] + 1;
dfs_ans(v, u, ans);
}
}
void dfs_minn(int u, int f, int val) {
if (d[u] <= k) minn[d[u]] = (val == 1 ? min(minn[d[u]], dep[u]) : inf);
for (auto e : g[u]) {
int v = e.first;
if (vis[v] || v == f) continue;
dfs_minn(v, u, val);
}
}
int work(int u, int f, int tot) {
dfs_rt(u, f, tot);
u = rt, vis[u] = 1, d[u] = 0, dep[u] = 0, minn[0] = 0;
int ans = inf;
for (auto e : g[u]) {
int v = e.first, w = e.second;
if (vis[v]) continue;
cnt = 0, d[v] = w, dep[v] = 1;
dfs_ans(v, u, ans);
sz[v] = cnt;
dfs_minn(v, u, 1);
}
dfs_minn(u, f, -1);
for (auto e : g[u]) {
int v = e.first;
if (vis[v]) continue;
ans = min(ans, work(v, u, sz[v]));
}
return ans;
}
int main() {
scanf("%d%d", &n, &k);
rep(i, 1, k) minn[i] = inf;
rep(i, 1, n - 1) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w), ++u, ++v;
g[u].push_back({v, w});
g[v].push_back({u, w});
}
int ans = work(1, 0, n);
printf("%d\n", ans == inf ? -1 : ans);
}