B. Alyona and a Tree (DSU on tree + BIT)
给定一棵以 $1$ 号节点为根的树,每个点有点权 $a_i$,边有边权。称 $v$ 控制了点 $u$,当且仅当 $u$ 是 $v$ 的子树中的节点($u \neq v$)且 $\mathrm{dis}(u, v) \leq a_u$;要求输出每个点控制的点数。
我们定义 $d(u)$ 为点 $1$ 到点 $u$ 的距离,则对于某个点 $v$ 来说,我们就是要在其子树中统计满足 $d(u) - d(v) \leq a_u$ 的点 $u$ 的个数,即 $d(u) - a_u \leq d(v)$。
对所有的 $d(u) - a_u$ 与 $d(u)$ 一起进行离散化,就可以用树上启发式合并(DSU on tree)+ 树状数组来完成上述统计,整体复杂度 $O(n \log^2 n)$。
#include <bits/stdc++.h>
using namespace std;
// Capacity for vertices / edges (the input has at most 2e5 vertices; +10 slack).
const int N = 200000 + 10;

// Edge lists in "chain-forward-star" form. Edge ids start at 1 so that 0
// marks the end of a chain: head[x] is x's first edge, nex[e] the following
// edge, to[e] its endpoint and value[e] its weight. cnt is the next free id.
int head[N];
int to[N];
int nex[N];
int value[N];
int cnt = 1;

// Per-vertex data. a[v] is read as the vertex weight and later overwritten
// by main with a compressed rank; ans[v] is the per-vertex answer.
int a[N];
int ans[N];
int sz[N];   // subtree size
int son[N];  // heavy child (0 when none)
int l[N];    // DFS entry timestamp
int r[N];    // last timestamp inside the subtree
int id[N];   // vertex at a given timestamp

// Fenwick-tree storage (positions 1..m) and assorted counters.
int sum[N << 3];
int tot;
int n;
int m;

// d[v]: distance from the root (later replaced by a compressed rank);
// b: coordinate-compression buffer holding two values per vertex.
long long d[N];
long long b[N << 1];
// Returns the lowest set bit of x (for x > 0), i.e. the length of the
// prefix range that Fenwick node x is responsible for.
inline int lowbit(int x) {
    // In two's complement, -x inverts every bit above the lowest set bit,
    // so AND-ing the two leaves exactly that bit.
    return (-x) & x;
}
// Records a directed edge x -> y with weight w. The new edge is linked in
// front of x's existing list, so each list is traversed newest-first.
void add(int x, int y, int w) {
    const int e = cnt++;
    to[e] = y;
    value[e] = w;
    nex[e] = head[x];
    head[x] = e;
}
// First pass over the subtree of rt (fa is its parent, 0 for the root).
// Fills in, for every vertex v visited:
//   d[v]        distance from the root (d[rt] itself must already be set),
//   sz[v]       subtree size,
//   son[v]      child with the largest subtree (first one wins ties; 0 for a leaf),
//   l[v], r[v]  timestamps such that subtree(v) = { id[t] : l[v] <= t <= r[v] },
//   id[t]       the vertex that received timestamp t.
void dfs(int rt, int fa) {
    sz[rt] = 1;
    l[rt] = ++tot;
    id[tot] = rt;
    for (int e = head[rt]; e != 0; e = nex[e]) {
        const int child = to[e];
        if (child != fa) {
            d[child] = d[rt] + value[e];
            dfs(child, rt);
            sz[rt] += sz[child];
            const bool heavier = son[rt] == 0 || sz[child] > sz[son[rt]];
            if (heavier) {
                son[rt] = child;
            }
        }
    }
    r[rt] = tot;
}
// Fenwick-tree point update: adds v at position x (positions run 1..m).
void update(int x, int v) {
    for (; x <= m; x += lowbit(x)) {
        sum[x] += v;
    }
}
// Fenwick-tree prefix query: returns the total stored at positions 1..x.
int query(int x) {
    int total = 0;
    for (; x != 0; x -= lowbit(x)) {
        total += sum[x];
    }
    return total;
}
void dfs(int rt, int fa, bool keep) {
for (int i = head[rt]; i; i = nex[i]) {
if (to[i] == fa || to[i] == son[rt]) {
continue;
}
dfs(to[i], rt, 0);
}
if (son[rt]) {
dfs(son[rt], rt, 1);
}
for (int i = head[rt]; i; i = nex[i]) {
if (to[i] == fa || to[i] == son[rt]) {
continue;
}
for (int j = l[to[i]]; j <= r[to[i]]; j++) {
update(d[id[j]], 1);
}
}
ans[rt] = query(a[rt]);
update(d[rt], 1);
if (!keep) {
for (int i = l[rt]; i <= r[rt]; i++) {
update(d[id[i]], -1);
}
}
}
int main() {
    // Input: n, the weights a_1..a_n, then for each vertex 2..n its parent
    // and the weight of the edge to that parent.
    scanf("%d", &n);
    for (int v = 1; v <= n; ++v) {
        scanf("%d", &a[v]);
    }
    int parent, w;
    for (int child = 2; child <= n; ++child) {
        scanf("%d %d", &parent, &w);
        add(parent, child, w);
    }
    dfs(1, 0);

    // Compress every d(v) and every d(v) - a_v onto one sorted, de-duplicated axis.
    for (int v = 1; v <= n; ++v) {
        b[++m] = d[v];
        b[++m] = d[v] - a[v];
    }
    sort(b + 1, b + 1 + m);
    m = unique(b + 1, b + 1 + m) - (b + 1);

    // Reuse the arrays: a[v] <- rank of d(v), d[v] <- rank of d(v) - a_v.
    for (int v = 1; v <= n; ++v) {
        const long long dist = d[v];
        const long long reach = dist - a[v];
        a[v] = lower_bound(b + 1, b + 1 + m, dist) - b;
        d[v] = lower_bound(b + 1, b + 1 + m, reach) - b;
    }
    dfs(1, 0, true);

    // One line, answers separated by single spaces.
    for (int v = 1; v <= n; ++v) {
        printf("%d%c", ans[v], v == n ? '\n' : ' ');
    }
    return 0;
}