P1600 [NOIP2016 提高组] 天天爱跑步
给定一颗有 n n n个点的树,有 m m m个人在树上移动,第 i i i个人从 s i s_i si点,移动到 t i t_i ti点,且他们按照最短路移动,每秒移动一条边的距离,
点 i i i在 w i w_i wi时刻有一个观察员,我们需要对每个点统计,在 w i w_i wi时刻有多少个人恰好到达这个点。
如果第 v v v个人,在 w u w_u wu时刻恰好出现在点 u u u,则一定有 d i s ( s v , u ) = w u dis(s_v, u) = w_u dis(sv,u)=wu,且 u u u在 s v , t v s_v, t_v sv,tv的路径上,
满足 d i s ( s , u ) + d i s ( u , t ) = d i s ( s , t ) dis(s, u) + dis(u, t) = dis(s, t) dis(s,u)+dis(u,t)=dis(s,t),假设 l c a ( s , t ) = z lca(s, t) = z lca(s,t)=z分两种情况讨论,以 1 1 1号节点为根, d ( i ) d(i) d(i)表示第 i i i号节点的深度,
-
u u u在 s − > z s->z s−>z的路径上,则有 d ( s ) − d ( u ) + d ( u ) + d ( t ) − 2 × d ( z ) = d ( s ) + d ( z ) − 2 × d ( z ) d(s) - d(u) + d(u) + d(t) - 2 \times d(z) = d(s) + d(z) - 2 \times d(z) d(s)−d(u)+d(u)+d(t)−2×d(z)=d(s)+d(z)−2×d(z),且 d ( s ) − d ( u ) = w u d(s) - d(u) = w_u d(s)−d(u)=wu。
-
u u u在 t − > z t->z t−>z的路径上,则有 d ( s ) + d ( u ) − 2 × d ( z ) + d ( t ) − d ( z ) = d ( s ) + d ( z ) − 2 × d ( z ) d(s) + d(u) - 2 \times d(z) + d(t) - d(z) = d(s) + d(z) - 2 \times d(z) d(s)+d(u)−2×d(z)+d(t)−d(z)=d(s)+d(z)−2×d(z),且 d ( s ) + d ( u ) − 2 × d ( z ) = w u d(s) + d(u) - 2 \times d(z) = w_u d(s)+d(u)−2×d(z)=wu。
前项都是符合要求的,所以看后面的两项 d ( s ) = d ( u ) + w u d(s) = d(u) + w_u d(s)=d(u)+wu, d ( u ) − w = 2 × d ( z ) − d ( s ) d(u) - w = 2 \times d(z) - d(s) d(u)−w=2×d(z)−d(s)。
考虑树上差分,在 s s s点插入 d ( s ) d(s) d(s),在 t t t点插入 2 × d ( z ) − d ( s ) 2 \times d(z) - d(s) 2×d(z)−d(s),
在 l c a lca lca处减去 d ( s ) d(s) d(s)的值,在 f a ( l c a ) fa(lca) fa(lca)处减去 2 × d ( z ) − d ( s ) 2 \times d(z) - d(s) 2×d(z)−d(s)的值,二者可互换顺序,之后只要线段树合并,再单点查询值即可。
#include <bits/stdc++.h>
using namespace std;
const int N = 3e5 + 10, maxn = 300000;
int head[N], to[N << 1], nex[N << 1], cnt = 1;
int dep[N], fa[N], son[N], sz[N], top[N];
int w[N], ans[N], n, m;
int root[N], ls[N << 5], rs[N << 5], sum[N << 5], num;
vector<pair<int, int>> a[N];
void add(int x, int y) {
to[cnt] = y;
nex[cnt] = head[x];
head[x] = cnt++;
}
void dfs1(int rt, int f) {
fa[rt] = f, dep[rt] = dep[f] + 1, sz[rt] = 1;
for (int i = head[rt]; i; i = nex[i]) {
if (to[i] == f) {
continue;
}
dfs1(to[i], rt);
sz[rt] += sz[to[i]];
if (!son[rt] || sz[to[i]] > sz[son[rt]]) {
son[rt] = to[i];
}
}
}
void dfs2(int rt, int tp) {
top[rt] = tp;
if (!son[rt]) {
return ;
}
dfs2(son[rt], tp);
for (int i = head[rt]; i; i = nex[i]) {
if (to[i] == son[rt] || to[i] == fa[rt]) {
continue;
}
dfs2(to[i], to[i]);
}
}
int lca(int u, int v) {
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) {
swap(u, v);
}
u = fa[top[u]];
}
return dep[u] < dep[v] ? u : v;
}
void update(int &rt, int l, int r, int x, int v) {
if (!rt) {
rt = ++num;
}
sum[rt] += v;
if (l == r) {
return ;
}
int mid = l + r >> 1;
if (x <= mid) {
update(ls[rt], l, mid, x, v);
}
else {
update(rs[rt], mid + 1, r, x, v);
}
}
int query(int rt, int l, int r, int x) {
if (l == r) {
return sum[rt];
}
int mid = l + r >> 1;
if (x <= mid) {
return query(ls[rt], l, mid, x);
}
else {
return query(rs[rt], mid + 1, r, x);
}
}
int merge(int x, int y, int l, int r) {
if (!x || !y) {
return x | y;
}
if (l == r) {
sum[x] += sum[y];
return x;
}
int mid = l + r >> 1;
ls[x] = merge(ls[x], ls[y], l, mid);
rs[x] = merge(rs[x], rs[y], mid + 1, r);
sum[x] = sum[ls[x]] + sum[rs[x]];
return x;
}
void dfs(int rt, int fa) {
for (auto it : a[rt]) {
update(root[rt], -maxn, maxn, it.first, it.second);
}
for (int i = head[rt]; i; i = nex[i]) {
if (to[i] == fa) {
continue;
}
dfs(to[i], rt);
root[rt] = merge(root[rt], root[to[i]], -maxn, maxn);
}
ans[rt] = query(root[rt], -maxn, maxn, dep[rt] + w[rt]);
if (w[rt]) {
ans[rt] += query(root[rt], -maxn, maxn, dep[rt] - w[rt]);
}
}
int main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
scanf("%d %d", &n, &m);
for (int i = 1, x, y; i < n; i++) {
scanf("%d %d", &x, &y);
add(x, y);
add(y, x);
}
dfs1(1, 0);
dfs2(1, 1);
for (int i = 1; i <= n; i++) {
scanf("%d", &w[i]);
}
for (int i = 1, s, t; i <= m; i++) {
scanf("%d %d", &s, &t);
int f = lca(s, t), ff = fa[f];
a[s].push_back({dep[s], 1});
a[t].push_back({2 * dep[f] - dep[s], 1});
a[f].push_back({dep[s], -1});
a[ff].push_back({2 * dep[f] - dep[s], -1});
}
dfs(1, 0);
for (int i = 1; i <= n; i++) {
printf("%d%c", ans[i], i == n ? '\n' : ' ');
}
return 0;
}