##题目描述
大意:给一颗
n
n
n个点的有根树,每个点有一个期待的数值
w
[
i
]
w[i]
w[i]。
m
m
m次操作,每次往
s
s
s 到
t
t
t 的路径上插入一个首项为
0
0
0公差为
1
1
1的等差数列。结束后问每个点处有多少个值等于这个点期待的数值。
链接在这里
##分析
我们来观察一下插入操作。考虑插入操作能够对哪些点产生影响,首先,需要在这个路径上,且期望值需要和等差数列对应项相同。前一个限制没有什么好说的,注意后一个限制。容易想到将这个路径分成两部分考虑,向上的(
s
s
s到
l
c
a
lca
lca)和向下的(
l
c
a
lca
lca到
t
t
t)。
向上走的链是否有什么性质呢?每向上走一个,等差数列值
+
1
+1
+1,深度
−
1
-1
−1,于是深度和等差数列值之和为定值,且等于起点的深度。于是对于一个向上走的链,当且仅当一个链上的点
u
u
u满足以下等式,该链会对他产生贡献:
d
e
p
[
u
]
+
w
[
u
]
=
d
e
p
[
s
]
dep[u]+w[u]=dep[s]
dep[u]+w[u]=dep[s]。于是我们可以对点进行分类,
d
e
p
[
u
]
+
w
[
u
]
dep[u]+w[u]
dep[u]+w[u]相同的在一类中。于是每此操作我们只需在对应类别中找到在该链上的,并对其
+
1
+1
+1即可。处理一条链上的点,树链剖分是容易想到的。于是我们对每一个
d
e
p
[
u
]
+
w
[
u
]
dep[u]+w[u]
dep[u]+w[u]的值开一个线段树,由于空间限制,动态开点即可。每次在
d
e
p
[
u
]
+
w
[
u
]
=
d
e
p
[
s
]
dep[u]+w[u] = dep[s]
dep[u]+w[u]=dep[s]处找到对应链并
+
1
+1
+1。
现在考虑向下走的链。方向变化之后,每次向下一层,等差数列值也加一。于是等差数列值与深度的差是定值,且等于
l
c
a
lca
lca处的值,即为
(
d
e
p
[
s
]
−
d
e
p
[
l
c
a
]
)
−
d
e
p
[
l
c
a
]
=
d
e
p
[
s
]
−
2
∗
d
e
p
[
l
c
a
]
(dep[s]-dep[lca])-dep[lca] = dep[s] - 2*dep[lca]
(dep[s]−dep[lca])−dep[lca]=dep[s]−2∗dep[lca]。所以当点
u
u
u满足如下式子时,可以对其产生贡献:
w
[
u
]
−
d
e
p
[
u
]
=
(
d
e
p
[
s
]
−
2
∗
d
e
p
[
l
c
a
]
)
w[u]-dep[u]=(dep[s]-2*dep[lca])
w[u]−dep[u]=(dep[s]−2∗dep[lca])。那么按照
w
[
u
]
−
d
e
p
[
u
]
w[u]-dep[u]
w[u]−dep[u]也进行分类,开若干颗线段树,运用树剖在对应类里修改。为了避免出现负数,我们把所有
w
[
u
]
−
d
e
p
[
u
]
w[u]-dep[u]
w[u]−dep[u]加上一个n即可。
现在考察一些细节。树剖修改的部分如何能正确区分这是向上走还是向下走呢?一种方式是分别从
s
s
s跳到
l
c
a
lca
lca找寻上升那条链的线段树,从
t
t
t跳到
l
c
a
lca
lca找寻下降那条链的线段树,然后最后删去一次lca(因为这样算了两遍)。另一种方法,也是我采取的方法是这样的。存下向上跳时要修改的线段树根节点编号,记为
R
T
1
RT1
RT1,和向下走时要修改的线段树根节点编号,记为
R
T
2
RT2
RT2。在树剖中
s
w
a
p
(
u
,
v
)
swap(u, v)
swap(u,v)时,同时
s
w
a
p
(
R
T
1
,
R
T
2
)
swap(RT1, RT2)
swap(RT1,RT2)。这在前面跳的过程中是显然正确的。现在考虑最后一段。如果当前我们是最后一次跳到的
t
o
p
top
top比较高,那么这时应该去跳另一侧的,此时要
s
w
a
p
(
u
,
v
)
,
s
w
a
p
(
R
T
1
,
R
T
2
)
swap(u, v), swap(RT1, RT2)
swap(u,v),swap(RT1,RT2);否则,我们仍然应该跳这一侧的,不必交换。于是我们便解决了本题。
##代码
#include <iostream>
#include <cstdio>
#include <bits/stdc++.h>
#include <algorithm>
#include <cstdlib>
#include <queue>
#include <set>
#define MAXN 300050
#define ri register int
using namespace std;
int n, m, ecnt, tcnt, idcnt, w[MAXN], ans[MAXN];
int fa[MAXN], dep[MAXN], top[MAXN], size[MAXN], son[MAXN], id[MAXN], rt[MAXN<<2], rt2[MAXN<<2];
struct node {
int v;
node *next;
}pool[MAXN<<2], *h[MAXN];
struct Node {
int ls, rs, lazy;
}t[MAXN<<5];
inline void adde(int u, int v) {
node *p = &pool[ecnt++], *q = &pool[ecnt++];
*p = node {v, h[u]}, h[u] = p;
*q = node {u, h[v]}, h[v] = q;
}
void ins(int &u, int l, int r, int x) {
if(!u) u = ++tcnt;
if(l == r) return ;
int mid = (l+r)>>1;
if(x <= mid) ins(t[u].ls, l, mid, x);
else ins(t[u].rs, mid+1, r, x);
}
void dfs1(int u) {
size[u] = 1;
for(node *p = h[u]; p; p = p->next) {
if(p->v != fa[u]) {
fa[p->v] = u, dep[p->v] = dep[u]+1;
dfs1(p->v);
size[u] += size[p->v];
if(size[son[u]] <= size[p->v]) son[u] = p->v;
}
}
}
void dfs2(int u, int t) {
top[u] = t;
id[u] = ++idcnt;
ins(rt[dep[u]+w[u]], 1, n, id[u]);
ins(rt2[w[u]-dep[u]+n], 1, n, id[u]);
if(!son[u]) return ;
dfs2(son[u], t);
for(node *p = h[u]; p; p = p->next)
if(!id[p->v]) dfs2(p->v, p->v);
}
void change(int u, int l, int r, int tl, int tr, int w) {
if(!u) return ;
if(tl <= l && r <= tr) return (void)(t[u].lazy += w);
int mid = (l+r)>>1;
if(tl <= mid) change(t[u].ls, l, mid, tl, tr, w);
if(mid < tr) change(t[u].rs, mid+1, r, tl, tr, w);
}
int Get(int u, int v) {
while(top[u] != top[v]) {
if(dep[top[u]] < dep[top[v]]) swap(u, v);
u = fa[top[u]];
}
if(id[u] > id[v]) swap(u, v);
return u;
}
void Change(int u, int v) {
int lca = Get(u, v);
int RT1 = rt[dep[u]], RT2 = rt2[dep[u]-2*dep[lca]+n];
while(top[u] != top[v]) {
if(dep[top[u]] < dep[top[v]]) swap(u, v), swap(RT1, RT2);
change(RT1, 1, n, id[top[u]], id[u], 1);
u = fa[top[u]];
}
if(id[v] > id[u]) swap(u, v), swap(RT1, RT2);
change(RT1, 1, n, id[v], id[u], 1);
}
int query(int u, int l, int r, int x) {
if(l == r) return t[u].lazy;
int mid = (l+r)>>1;
if(t[u].lazy)
t[t[u].ls].lazy += t[u].lazy, t[t[u].rs].lazy += t[u].lazy, t[u].lazy = 0;
if(x <= mid) return query(t[u].ls, l, mid, x);
return query(t[u].rs, mid+1, r, x);
}
int main () {
int u, v;
scanf("%d%d", &n, &m);
for(ri i = 1; i < n; ++i) scanf("%d%d", &u, &v), adde(u, v);
for(ri i = 1; i <= n; ++i) scanf("%d", &w[i]);
dfs1(1), dfs2(1, 1);
while(m--) {
scanf("%d%d", &u, &v);
Change(u, v);
}
for(ri i = 1; i <= n; ++i) {
ans[i] = query(rt[dep[i]+w[i]], 1, n, id[i]);
ans[i] += query(rt2[w[i]-dep[i]+n], 1, n, id[i]);
printf("%d ", ans[i]);
}
return 0;
}