【题意】
n n n个节点的树,每个节点、每条边有权值,定义 d i s t ( u , v ) dist(u,v) dist(u,v)为路径 < u , v > <u,v> <u,v>的边权和。如果 u u u在子树 v v v内且 d i s t ( u , v ) ≤ v a l [ u ] dist(u,v) \leq val[u] dist(u,v)≤val[u],则称 v v v控制 u u u。问每个节点控制多少个点。
【分析】
不难想到对于每一个点,它被多少个点控制。
我们可以二分最远的被控点是它的第几代祖先,计算距离即可。
然后就是对我们二分出来的这条链
<
u
,
v
>
<u,v>
<u,v>的答案加一即可。
这个过程虽然可以用树链剖分+线段树实现,但代码量……憋说了我不想写
如果这个问题是在一个线性结构上,我们肯定会选择差分:开一个数组
d
d
d,对于要覆盖的区间
[
l
,
r
]
[l,r]
[l,r],将
d
[
l
]
+
1
,
d
[
r
+
1
]
−
1
d[l]+1, d[r+1]-1
d[l]+1,d[r+1]−1,然后前缀和就是对应位置的值。
我们尝试把它搬到树上。此时,我们就需要覆盖一条路径。依旧开一个数组
d
d
d,表示从根节点到当前节点的路径被覆盖了多少次,这样,我们就可以对一条路径
<
u
,
v
>
,
d
[
u
]
+
1
,
d
[
v
]
+
1
,
d
[
l
c
a
(
u
,
v
)
]
−
1
,
d
[
f
a
[
l
c
a
(
u
,
v
)
]
]
−
1
<u,v>,d[u]+1,d[v]+1,d[lca(u,v)]-1,d[fa[lca(u,v)]]-1
<u,v>,d[u]+1,d[v]+1,d[lca(u,v)]−1,d[fa[lca(u,v)]]−1,每一个节点
p
p
p的覆盖次数就是子树的d之和
+
d
[
p
]
+d[p]
+d[p]。
对应到这个题,只需让
d
[
f
a
[
u
]
]
+
1
,
d
[
f
a
[
v
]
]
−
1
d[fa[u]]+1,d[fa[v]]-1
d[fa[u]]+1,d[fa[v]]−1即可。
【代码】
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int mn = 200005;
vector<int> g[mn], v[mn];
int a[mn], f[mn][20], dis[mn], ans[mn], d[mn];
ll sum[mn][20];
void dfs(int s, int fa, int d)
{
f[s][0] = fa, dis[s] = d;
int m = g[s].size();
for(int i = 0; i < m; i++)
{
int t = g[s][i];
if(t != fa) sum[t][0] = v[s][i], dfs(t, s, d + 1);
}
}
void dfs2(int s)
{
ans[s] = 0;
int m = g[s].size();
for(int i = 0; i < m; i++)
{
int t = g[s][i];
if(t != f[s][0]) dfs2(t), ans[s] += ans[t];
}
ans[s] += d[s];
}
inline ll get_sum(int a, int k)
{
ll ret = 0; int d = 19;
while(d >= 0)
{
if((k >> d) & 1)
ret += sum[a][d], a = f[a][d];
--d;
}
return ret;
}
inline int get_ver(int a, int k)
{
int d = 19;
while(d >= 0)
{
if((k >> d) & 1)
a = f[a][d];
--d;
}
return a;
}
int main()
{
int n, x, y;
scanf("%d", &n);
for(int i = 1; i <= n; i++)
scanf("%d", &a[i]);
for(int i = 2; i <= n; i++)
scanf("%d%d", &x, &y), g[x].push_back(i), g[i].push_back(x),
v[x].push_back(y), v[i].push_back(y);
dfs(1, 0, 0);
for(int j = 1; j < 20; j++)
for(int i = 1; i <= n; i++)
f[i][j] = f[f[i][j-1]][j-1], sum[i][j] = sum[i][j-1] + sum[f[i][j-1]][j-1];
for(int i = 1; i <= n; i++)
{
int l = 1, r = dis[i], p = 0x3f3f3f3f;
while(l <= r)
{
int mid = (l + r) >> 1;
if(get_sum(i, mid) > a[i]) r = mid - 1;
else p = get_ver(i, mid), l = mid + 1;
}
if(p != 0x3f3f3f3f)
--d[f[p][0]], ++d[f[i][0]];
}
dfs2(1), printf("%d", ans[1]);
for(int i = 2; i <= n; i++)
printf(" %d", ans[i]);
puts("");
}