【模板】树上 k 级祖先
题目链接:luogu P5903
题目大意
给你一棵树,要你在线 O(1) 求一个点的 k 级祖先。
思路
这个我们可以用长链剖分来做,从而可以做到预处理 O ( n log n ) O(n\log n) O(nlogn) 询问 O ( 1 ) O(1) O(1)。
首先预处理:
最数进行长链剖分(就是找最长的一条链,可以通过深度来看),也是记录对应儿子。(
O
(
n
)
O(n)
O(n))
然后倍增求
2
k
2^k
2k 级父亲。(
O
(
n
log
n
)
O(n\log n)
O(nlogn))
然后每条长链,如果长度为
l
e
n
len
len,我们就求出从它的顶点出发,它向上的
l
e
n
len
len 个祖先和向下走长链的
l
e
n
len
len 个儿子。(
O
(
n
)
O(n)
O(n))
然后再求每个数二进制最高位
g
k
g_k
gk。(
O
(
n
)
O(n)
O(n))
然后看询问怎么回答:
首先用倍增数组跳到
k
k
k 的最高位级父亲,那假设剩下
k
′
k'
k′ 级,那首先由
k
′
<
2
g
k
k'<2^{g_k}
k′<2gk,而且因为是长链,所以
x
x
x 所在的长链的长度一定
⩾
2
g
k
>
k
′
\geqslant 2^{g_k}>k'
⩾2gk>k′。
那根据
l
e
n
len
len(长链长度)
>
k
′
>k'
>k′,所以我们可以跳到链的顶点,然后看
k
′
−
l
e
n
=
k
′
′
k'-len=k''
k′−len=k′′:
如果
k
′
′
>
0
k''>0
k′′>0,说明要的点还在上面,就用上面的。
如果
k
′
′
<
0
k''<0
k′′<0,说明要的点在下面,就用下面的。
(这两个就是
n
log
n
n\log n
nlogn 预处理出来的那个)
然后就可以了。
代码
#include<cstdio>
#include<vector>
#define ll long long
using namespace std;
const int N = 5e5 + 100;
int n, q, f[N][21], rt, lstans;
int deg[N], d[N], son[N], top[N], g[N];
vector <int> G[N], up[N], down[N];
ll ans;
#define ui unsigned int
ui s;
inline ui get(ui x) {
x ^= x << 13;
x ^= x >> 17;
x ^= x << 5;
return s = x;
}
void dfs0(int now, int father) {
d[now] = d[father] + 1; deg[now] = d[now];
for (int i = 1; i <= 20; i++) f[now][i] = f[f[now][i - 1]][i - 1];
for (int i = 0; i < G[now].size(); i++) {
int x = G[now][i]; dfs0(x, now);
if (deg[x] > deg[now]) deg[now] = deg[x], son[now] = x;
}
}
void dfs1(int now, int father) {
if (now == top[now]) {
for (int i = 0, x = now; i <= deg[now] - d[now]; i++)
up[now].push_back(x), x = f[x][0];
for (int i = 0, x = now; i <= deg[now] - d[now]; i++)
down[now].push_back(x), x = son[x];
}
if (son[now]) top[son[now]] = top[now], dfs1(son[now], now);
for (int i = 0; i < G[now].size(); i++) {
int x = G[now][i]; if (x == son[now]) continue;
top[x] = x; dfs1(x, now);
}
}
int Jump(int x, int k) {
if (!k) return x;
x = f[x][g[k]]; k -= (1 << g[k]);//先跳最高的 1
k -= d[x] - d[top[x]]; x = top[x];//跳到这个长链的顶部
return k >= 0 ? up[x][k] : down[x][-k];//直接找
}
int main() {
scanf("%d %d %u", &n, &q, &s); g[0] = -1;
for (int i = 1; i <= n; i++) {
scanf("%d", &f[i][0]); if (!f[i][0]) rt = i;
G[f[i][0]].push_back(i); g[i] = g[i >> 1] + 1;
}
dfs0(rt, 0); top[rt] = rt; dfs1(rt, 0);
for (int qq = 1; qq <= q; qq++) {
int x = ((get(s) ^ lstans) % n) + 1, k = (get(s) ^ lstans) % d[x];
lstans = Jump(x, k); ans ^= 1ll * qq * lstans;
}
printf("%lld", ans);
return 0;
}