题意
题解
点分治
以 p p p 为根节点的树上路径可分为两类:经过 p p p 的路径;不经过 p p p 的路径。那么可以在处理 p p p 时处理第一类路径,再递归子树处理第二类路径。
路径权值和 k k k 范围较小,使用桶 m n mn mn 记录节点至根节点的路径中权值和为 [ 0 , k ] [0,k] [0,k] 的最小深度(路径数),依次遍历子树,以 D F S DFS DFS 时间戳为索引记录子树上节点至根节点的深度 d e p dep dep 与距离 d s ds ds,并进行可行性与最优性剪枝,利用之前遍历的子树更新的桶 m n mn mn 来更新答案 min { d e p [ i ] + m n [ d s [ i ] ] } \min\{dep[i]+mn[ds[i]]\} min{dep[i]+mn[ds[i]]},此时保证经过根节点的两条子路径不属于同一颗子树。最后使用之前记录的时间戳清除更新的信息。每一层时间复杂度 O ( N ) O(N) O(N)。
在树上点分治,每次选取重心进行分解,递归 O ( log N ) O(\log N) O(logN) 层。总时间复杂度 O ( N log N ) O(N\log N) O(NlogN)。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 200005, maxe = maxn << 1, maxk = 1000005, inf = 0x3f3f3f3f;
int N, K, res, tot, dep[maxn], ds[maxn], mn[maxk];
int mx, w, sz[maxn];
int E, head[maxn], to[maxe], nxt[maxe], cost[maxe];
bool del[maxn];
inline int read()
{
int x = 0;
char c = 0;
for (; c < '0' || c > '9'; c = getchar())
;
for (; c >= '0' && c <= '9'; c = getchar())
x = (x << 1) + (x << 3) + c - '0';
return x;
}
void add(int x, int y, int z) { to[++E] = y, cost[E] = z, nxt[E] = head[x], head[x] = E; }
void find(int n, int x, int f)
{
sz[x] = 1;
int m = 0;
for (int i = head[x]; i; i = nxt[i])
{
int y = to[i];
if (y != f && !del[y])
find(n, y, x), sz[x] += sz[y], m = max(m, sz[y]);
}
m = max(m, n - sz[x]);
if (m < mx)
mx = m, w = x;
}
void get_sz(int x, int f)
{
sz[x] = 1;
for (int i = head[x]; i; i = nxt[i])
{
int y = to[i];
if (y != f && !del[y])
get_sz(y, x), sz[x] += sz[y];
}
}
void get_d(int x, int f, int d, int c)
{
if (c > K || d >= res)
return;
ds[++tot] = c, dep[tot] = d;
for (int i = head[x]; i; i = nxt[i])
{
int y = to[i], z = cost[i];
if (y != f && !del[y])
get_d(y, x, d + 1, c + z);
}
}
void solve(int n, int x)
{
mx = n, w = -1;
find(n, x, -1);
ds[tot = 1] = 0, mn[0] = 0;
for (int i = head[w], pre = 2, cur; i; i = nxt[i])
{
int y = to[i], z = cost[i];
if (!del[y])
{
get_d(y, w, 1, z), get_sz(y, w), cur = tot + 1;
for (int j = pre; j < cur; ++j)
res = min(res, dep[j] + mn[K - ds[j]]);
for (int j = pre; j < cur; ++j)
mn[ds[j]] = min(mn[ds[j]], dep[j]);
pre = cur;
}
}
for (int i = 1; i <= tot; ++i)
mn[ds[i]] = inf;
del[w] = 1;
for (int i = head[w]; i; i = nxt[i])
{
int y = to[i];
if (!del[y])
solve(sz[y], y);
}
}
int main()
{
N = read(), K = read();
for (int i = 1, x, y, z; i < N; ++i)
x = read(), y = read(), z = read(), add(x, y, z), add(y, x, z);
res = maxn;
memset(mn, 0x3f, sizeof(int) * (K + 1));
solve(N, 0);
printf("%d\n", res == maxn ? -1 : res);
return 0;
}
dsu on tree
设根节点为
x
x
x,树上节点到根节点深度、距离分别为
d
e
p
,
d
s
dep,ds
dep,ds。对于树上两个节点
y
,
z
y,z
y,z,若其路径权值和为
k
k
k,则满足
(
d
s
[
y
]
−
d
s
[
x
]
)
+
(
d
s
[
z
]
−
d
s
[
x
]
)
=
k
(ds[y]-ds[x])+(ds[z]-ds[x])=k
(ds[y]−ds[x])+(ds[z]−ds[x])=k 即已知
d
s
[
y
]
ds[y]
ds[y],则
d
s
[
z
]
=
k
+
2
d
s
[
x
]
−
d
s
[
y
]
ds[z]=k+2ds[x]-ds[y]
ds[z]=k+2ds[x]−ds[y],此时路径上边数为
(
d
e
p
[
y
]
−
d
e
p
[
x
]
)
+
(
d
e
p
[
z
]
−
d
e
p
[
x
]
)
(dep[y]-dep[x])+(dep[z]-dep[x])
(dep[y]−dep[x])+(dep[z]−dep[x]) 预处理出重儿子(节点数最多的儿子)的同时计算各节点到根节点的距离与深度,使用哈希表记录节点距离到最小节点深度的映射,使用树上启发式合并更新答案。总时间复杂度
O
(
N
log
N
)
O(N\log N)
O(NlogN)。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 200005, maxe = maxn << 1, inf = 0x3f3f3f3f;
int N, K, res, hs[maxn], sz[maxn], dep[maxn];
ll ds[maxn];
int E, head[maxn], to[maxe], nxt[maxe], cost[maxe];
unordered_map<ll, int> mn;
inline int read()
{
int x = 0;
char c = 0;
for (; c < '0' || c > '9'; c = getchar())
;
for (; c >= '0' && c <= '9'; c = getchar())
x = (x << 1) + (x << 3) + c - '0';
return x;
}
void add(int x, int y, int z) { to[++E] = y, cost[E] = z, nxt[E] = head[x], head[x] = E; }
void dfs(int x, int f, int d, ll c)
{
sz[x] = 1, dep[x] = d, ds[x] = c;
for (int i = head[x]; i; i = nxt[i])
{
int y = to[i], z = cost[i];
if (y != f)
dfs(y, x, d + 1, c + z), sz[x] += sz[y], hs[x] = sz[hs[x]] < sz[y] ? y : hs[x];
}
}
void upd(int x, int f, int k)
{
if (k < 0)
mn[ds[x]] = inf;
else
{
if (!mn.count(ds[x]))
mn[ds[x]] = dep[x];
else
mn[ds[x]] = min(mn[ds[x]], dep[x]);
}
for (int i = head[x]; i; i = nxt[i])
{
int y = to[i];
if (y != f)
upd(y, x, k);
}
}
void get_res(int x, int f, int z)
{
ll c = K + (ds[z] << 1) - ds[x];
if (mn.count(c))
res = min(res, mn[c] + dep[x] - (dep[z] << 1));
for (int i = head[x]; i; i = nxt[i])
{
int y = to[i];
if (y != f)
get_res(y, x, z);
}
}
void solve(int x, int f, int keep)
{
for (int i = head[x]; i; i = nxt[i])
{
int y = to[i];
if (y != f && y != hs[x])
solve(y, x, 0);
}
if (hs[x])
solve(hs[x], x, 1);
if (mn.count(ds[x] + K))
res = min(res, mn[ds[x] + K] - dep[x]);
mn[ds[x]] = dep[x];
for (int i = head[x]; i; i = nxt[i])
{
int y = to[i];
if (y != f && y != hs[x])
get_res(y, x, x), upd(y, x, 1);
}
if (!keep)
upd(x, f, -1);
}
int main()
{
N = read(), K = read();
for (int i = 1, x, y, z; i < N; ++i)
x = read() + 1, y = read() + 1, z = read(), add(x, y, z), add(y, x, z);
dfs(1, 0, 0, 0);
res = maxn;
solve(1, 0, 1);
printf("%d\n", res == maxn ? -1 : res);
return 0;
}