这题我一看的时候是没有什么思路的。题目意思很明确,但是不知道如何高效地求出来。
看到数据范围,感觉 40 分的暴力应该是可以的。
但是写暴力也要有技巧。经过我一番慎重的思考和比较,发现从“当前节点能控制多少后代”和“当前节点能被多少祖先控制”两个角度考虑是不一样的。
虽然题目的设问是从前者的角度,但是每个节点可能会有很多后代,而每个结点只有唯一的父亲,唯一的父亲的父亲……
有了这个前提,我就很快想到了一种简单粗暴的方法:将每个节点作为
v
考虑,各跑一遍 dfs,记一个参数
也就是说,题目要求的是
每次去向
v′
的父亲
u
时,有
也就是每次递归的时候更新参数只需减去当前节点与父亲之间的边权即可。这种方法编程十分容易,我只写了 10 分钟不到,代码 31 行,拿了 50 分。
代码:
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
using namespace std;
const int MAXN = 1e5 + 100;
int N, M;
int a[MAXN], par[MAXN], ans[MAXN];
long long len[MAXN]; //注意数据范围,答案可能很大,建议使用 long long 类型
void dfs(int root, long long dis) {
if (dis < 0 || !root) return; //边界条件
++ans[root]; //当前的 i 可以被 root 控制
dfs(par[root], dis - len[root]); //剩余距离减去相应的边权
}
int main(void) {
freopen("2095.in", "r", stdin);
freopen("2095.out", "w", stdout);
scanf("%d", &N);
for (int i = 1; i <= N; i++) scanf("%d", &a[i]);
for (int i = 2; i <= N; i++) scanf("%d%lld", &par[i], &len[i]);
for (int i = 2; i <= N; i++) dfs(par[i], a[i] - len[i]);
for (int i = 1; i <= N; i++) printf("%d ", ans[i]);
return 0;
}
但是,我们当然不能就此止步。
上面这个想法,“考虑后代对祖先的贡献”其实是非常好的,反题目之道而行之,而正解也是要从这一步入手。
首先分析一下,上面的算法为什么会慢,其实关键就在 dfs 部分,一步一步往上找,太浪费时间。
我们知道,在树上,每个节点
i
到根节点的距离是易求的,跑一遍 dfs 就可以算出来了,不妨记为
而在控制关系中,
u
为
现在不难发现,对于所枚举的一个确定的
v
,
则任务可转化为:找到满足
sumu
大于等于这个定值且深度最小的 u。求得后,
u
到
通过计算还可以发现,对于一条从祖先到后代的路径上的节点
i
,
二分!
现在来看要找的要求,不是正好满足求一个 lower_bound 的特征嘛!
如何实现二分呢?有两种方法。
第一种是我在考试的时候现场手推的倍增(以前学过,但是忘了),Ghastlcon 也用的是这种方法。关于倍增这里不再详细介绍,详细资料请自行查询。
记
在二分的时候
l,r
组成树上一条祖先到后代的路径,
u
在这个路径上,位置待定。而这条路径的中点
如果
summid
小于之前所说的定值,说明
mid
在太高的位置了,已经控制不到
v
,因此
最后就可以求得满足条件且深度最小的 u 。之后怎么做,暂时先不说。
第二种方法我感觉比较巧妙,是 aspe 提出的,不用预处理,只需借助一个栈。
不用直接按编号枚举
但现在我们也仅仅只是对于每一个
显然,如果再去一步一步往上,给这条路径上的点加答案,那么这一步最坏情况都要花
这里就要用到差分思想。
回顾这样一个问题:在一个数轴上,每次指定一段 [li,ri] ,给这个区间内的数加上 1,最后询问某个点的值。相信很多人都会想到在 li 加 1,在 ri+1 的位置减 1,从左往右做一遍关于这些标记的前缀和,就可以了。这样一来就把对区间的操作转化成对端点的操作,其实就是差分思想。具体的应用,可以结合这篇文章的讲解,并参考 GCOI2015 小学六年级组的“计时器”一题。
现在回顾到本题,其实也可以在树上做差分(事实上,这还将会是一种相当常用且好用的技巧)。可以把一条条路径就想象为数轴上的一个个区间,就可以通过在
v
的父亲结点打一个 +1 标记,在
最后跑的这遍 dfs,先递归到叶子结点,则它们的
ansi
都为 0(显然它们没有后代,不可能控制其他节点)。而对于非叶节点,它的
ansi
就等于它所有儿子的
ansj
之和(想象一下,相当于
i
延续着一条从下面一直到
到这里,本题就被完美解决了。时间复杂度(我用的是方法一)为 O(nloglogn) 。
参考代码:
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
using namespace std;
const int MAXN = 1e5 + 100;
struct EDGE { int to, next; long long cost; } edges[MAXN << 1]; //此处其实开一倍足矣,但在保证不 MLE 的前提下开两倍更保险。
int N, M;
int maxd, dep[MAXN], upper_lim;
long long a[MAXN], par[MAXN][20], len[MAXN], sum[MAXN], mark[MAXN], ans[MAXN];
int head[MAXN];
void add_edge(int u, int v, long long w) {
edges[M++] = (EDGE){v, head[u], w};
head[u] = M - 1;
}
void dfs1(int pre, int root, long long dis) { //负责求出各节点到根节点的距离并统计深度
sum[root] = dis; maxd = max(maxd, dep[root]);
for (int i = head[root]; i != -1; i = edges[i].next) {
int v = edges[i].to;
if (v != pre) {
dep[v] = dep[root] + 1;
dfs1(root, v, dis + len[v]);
}
}
}
int query(int r, int d) { //倍增思想,借助 par 数组询问节点 r 的第 d 代父亲
for (int p = 0; d; p++, d >>= 1) if (d & 1) r = par[r][p];
return r;
}
void solve(int cur) {
if (len[cur] > a[cur]) return; //连父亲都控制不了自己
int l = 0, r = par[cur][0]; //(l, r]
while (dep[l] + 1 < dep[r]) {
int mid = query(r, dep[r] - dep[l] >> 1);
if (sum[cur] - sum[mid] <= a[cur]) r = mid; else l = mid;
//这里我在比赛的时候犯了一个致命的错误,原先我的 l,r 表示的是 u 在 cur 的第 2^l 代父亲和第 2^r 父亲之间
//后来直接用 l 和 r 表示具体节点,结果忘记把 sum[par[cur][mid]] 改成 sum[mid] 了,浪费了大量时间进行调试
}
++mark[par[cur][0]]; //从节点 cur 的父亲开始可以控制 cur
--mark[par[r][0]]; //到节点 r 的父亲处停止控制
}
void dfs2(int pre, int root) {
ans[root] = mark[root]; //自身标记
for (int i = head[root]; i != -1; i = edges[i].next) {
int v = edges[i].to;
if (v != pre) {
dfs2(root, v);
ans[root] += ans[v]; //各儿子答案之和
}
}
}
int main(void) {
freopen("2095.in", "r", stdin);
freopen("2095.out", "w", stdout);
scanf("%d", &N);
for (int i = 1; i <= N; i++) scanf("%lld", &a[i]);
memset(head, -1, sizeof head);
for (int i = 2; i <= N; i++) {
scanf("%d%lld", &par[i][0], &len[i]);
add_edge(par[i][0], i, len[i]); //可以通过 par[i][0] 直接去向父亲,因此只需连一条父亲到儿子的边。
}
dep[0] = -1; dfs1(0, 1, 0);
upper_lim = log2(maxd);
for (int i = 1; i <= upper_lim; i++) //预处理部分
for (int j = 2; j <= N; j++)
par[j][i] = par[par[j][i - 1]][i - 1];
/*
for (int i = 2; i <= N; i++) {
for (int j = 0; j <= upper_lim; j++) printf("%d ", par[i][j]);
putchar('\n');
}
*/
for (int i = 2; i <= N; i++) solve(i);
dfs2(0, 1);
for (int i = 1; i <= N; i++) printf("%lld ", ans[i]);
return 0;
}