Ancient Distance
给定一颗根为 1 1 1有 n n n个节点的树,每次可以选定树上 k k k节点当作特殊节点,
定义 d i s ( u ) dis(u) dis(u)为,从 u − > 1 u->1 u−>1遇上的第一个特殊点的距离,如果遇不上特殊点则 d i s ( u ) dis(u) dis(u)无穷大。
有 n n n次询问,问,每次选 k ∈ { 1 , 2 , 3 , … , n − 1 , n } k \in \{1, 2, 3, \dots, n - 1, n\} k∈{1,2,3,…,n−1,n}个特殊点时的答案,
有一个性质,最大答案为 n − 1 n - 1 n−1,且 1 1 1号点是一定要选的,接下来考虑其他的点如何选取,
假设我们当前答案为 x x x,我们需要选取多少个点,有一个贪心的想法,找到一个节点最深的节点,然后把他的第 x x x代祖先设置为特殊点,
这样我们就保证了这一子树都满足答案小于等于 x x x,按照这样依次操作,最后我们的答案都会小于 x x x,
不难发现对于每个 x x x,我们所需执行的操作最多不会超过 ⌈ n x ⌉ \lceil \frac{n}{x} \rceil ⌈xn⌉,我们可以利用线段树来查询每次需要操作的点,这样保证了一次操作是 log n \log n logn的,
由此我们发现整体复杂度是 ∑ i = 1 n ⌈ n i ⌉ log n = O ( n log n log n ) \sum\limits_{i = 1} ^{n} \lceil \frac{n}{i} \rceil \log n = O(n \log n \log n) i=1∑n⌈in⌉logn=O(nlognlogn)的。
#include <bits/stdc++.h>
#define mid (l + r >> 1)
#define lson rt << 1, l, mid
#define rson rt << 1 | 1, mid + 1, r
#define ls rt << 1
#define rs rt << 1 | 1
using namespace std;
const int N = 2e5 + 10;
int maxn[N << 2], id[N << 2], cov[N << 2], ans[N], n;
int l[N], r[N], rk[N], fa[N][21], dep[N], tot;
vector<int> G[N];
void dfs(int rt, int f) {
l[rt] = ++tot, rk[tot] = rt, fa[rt][0] = f, dep[rt] = dep[f] + 1;
for (int i = 1; i <= 20; i++) {
fa[rt][i] = fa[fa[rt][i - 1]][i - 1];
}
for (int &to : G[rt]) {
if (to == f) {
continue;
}
dfs(to, rt);
}
r[rt] = tot;
}
int k_fa(int rt, int k) {
for (int i = 20; i >= 0; i--) {
if (k >> i & 1) {
rt = fa[rt][i];
}
}
return rt;
}
void push_up(int rt) {
maxn[rt] = 0;
if (!cov[ls] && maxn[ls] > maxn[rt]) {
maxn[rt] = maxn[ls];
id[rt] = id[ls];
}
if (!cov[rs] && maxn[rs] > maxn[rt]) {
maxn[rt] = maxn[rs];
id[rt] = id[rs];
}
}
void build(int rt, int l, int r) {
cov[rt] = 0;
if (l == r) {
maxn[rt] = dep[rk[l]];
id[rt] = rk[l];
return ;
}
build(lson);
build(rson);
push_up(rt);
}
void update(int rt, int l, int r, int L, int R, int v) {
if (l >= L && r <= R) {
cov[rt] = v;
return ;
}
if (L <= mid) {
update(lson, L, R, v);
}
if (R > mid) {
update(rson, L, R, v);
}
push_up(rt);
}
int main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
while (scanf("%d", &n) != EOF) {
tot = 0;
for (int i = 1; i <= n; i++) {
G[i].clear();
}
for (int i = 2, x; i <= n; i++) {
scanf("%d", &x);
G[x].push_back(i);
G[i].push_back(x);
}
dep[0] = -1;
dfs(1, 0);
build(1, 1, n);
for (int i = 1; i <= n; i++) {
ans[i] = n;
}
vector<int> vt;
for (int cur = n - 1; cur >= 0; cur--) {
int num = 1;
vt.clear();
while (true) {
if (maxn[1] <= cur) {
break;
}
num++;
int u = k_fa(id[1], cur);
vt.push_back(u);
update(1, 1, n, l[u], r[u], 1);
}
ans[num] = cur;
for (auto rt : vt) {
update(1, 1, n, l[rt], r[rt], 0);
}
}
for (int i = 2; i <= n; i++) {
ans[i] = min(ans[i], ans[i - 1]);
}
long long res = 0;
for (int i = 1; i <= n; i++) {
res += ans[i];
}
printf("%lld\n", res);
}
return 0;
}