题意
给定 n n n 个点的有根树,根节点为 1 1 1。每个点有颜色 c i c_i ci。定义深度 dep i \text{dep}_i depi 为 i i i 节点到 1 1 1 节点边的个数。
m m m 个问题,每次给定 x , d x, d x,d,询问 x x x 的子树中所有深度在 [ dep x , dep x + d ] [\text{dep}_x, \text{dep}_x + d] [depx,depx+d] 内的点中出现了多少种不同的颜色。强制在线。
数据范围
1 ≤ n , m ≤ 1 0 5 , 1 ≤ c i ≤ n 1\le n, m\le 10^5, 1\le c_i\le n 1≤n,m≤105,1≤ci≤n。
题解
对于在序列上的问题,我们已经有了 HH的项链。在这个问题中,我们离线,向右滑动右端点,并且只关注最靠右的颜色。当将右端点滑到 r r r 时,我们将之前的颜色 r r r 的信息删除,并加入这个位置的信息。
回到现在的问题,我们仍然沿用这个方法,我们希望在一个子树中,对于每种颜色只留下深度最小的点的信息。所以在 u u u 的子树中只需要对于一种颜色 i i i 记录 d u , i d_{u,i} du,i 表示子树中所有颜色为 i i i 的点的最小深度。同时在 u u u 的子树中我们记录 x u , i x_{u,i} xu,i 表示深度为 i i i 的不同颜色个数,这个数组应当是数组 d u , i d_{u,i} du,i 的桶,也即 x u , i = ∑ j = 1 n [ d u , j = i ] x_{u,i}=\sum\limits_{j=1}^n [d_{u,j}=i] xu,i=j=1∑n[du,j=i]。我们询问时只需要查询 ∑ j = dep x dep x + d x u , j \sum\limits_{j=\text{dep}_x}^{\text{dep}_x+d} x_{u,j} j=depx∑depx+dxu,j 即可。
我们到了 u u u,处理完所有儿子的信息,如何合并?假设枚举到儿子 v v v,我们先将两个桶相加,就是 x u , i ← x u , i + x v , i x_{u,i}\leftarrow x_{u,i}+x_{v,i} xu,i←xu,i+xv,i。接下来枚举所有颜色,如果两个数组都有这个颜色,则因为我们只保留深度最小的,我们就需要将 max { d u , i , d v , i } \max\{d_{u,i},d_{v,i}\} max{du,i,dv,i} 除去,也即 x max { d u , i , d v , i } ← x max { d u , i , d v , i } − 1 x_{\max\{d_{u,i},d_{v,i}\}}\leftarrow x_{\max\{d_{u,i},d_{v,i}\}}-1 xmax{du,i,dv,i}←xmax{du,i,dv,i}−1,同时 d u , i ← min { d u , i , d v , i } d_{u,i}\leftarrow \min\{d_{u,i},d_{v,i}\} du,i←min{du,i,dv,i}。
明眼人已经看出这个过程可以使用线段树合并来维护。
现在丧心病狂的一点来了,这题要强制在线。
因此我们需要保留每个点的线段树,所以我们需要:可持久化线段树合并。
这事啥意思?且看下图:
我们的解决方案是:合并两个线段树时,对于每个公共的点,新建一个节点,在其上修改。
由于线段树合并的时间复杂度是正确的,因此这么做的空间复杂度也是正确的。
时间 / 空间复杂度都是 O ( n log n ) O(n\log n) O(nlogn)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 100005;
int T, N, M, fa[MAXN], c[MAXN], dep[MAXN];
struct node { int v, next; } E[MAXN]; int head[MAXN], Elen;
void add(int u, int v) { ++Elen, E[Elen].v = v, E[Elen].next = head[u], head[u] = Elen; }
void dfs(int u) {
dep[u] = dep[fa[u]] + 1;
for (int i = head[u]; i; i = E[i].next) dfs(E[i].v);
}
namespace T1 {
int val[MAXN * 100], ls[MAXN * 100], rs[MAXN * 100], Tlen;
#define mid ((l + r) >> 1)
void newnode(int& o) { ++Tlen, val[Tlen] = val[o], ls[Tlen] = ls[o], rs[Tlen] = rs[o], o = Tlen; }
void insert(int& o, int l, int r, int pos, int k) {
newnode(o);
if (l == r) val[o] += k;
else {
if (pos <= mid) insert(ls[o], l, mid, pos, k);
else insert(rs[o], mid + 1, r, pos, k);
val[o] = val[ls[o]] + val[rs[o]];
}
}
void merge(int& o, int l, int r, int old) {
if (!old) return;
if (!o) { o = old; return; }
newnode(o);
if (l == r) val[o] += val[old];
else merge(ls[o], l, mid, ls[old]), merge(rs[o], mid + 1, r, rs[old]), val[o] = val[ls[o]] + val[rs[o]];
}
int query(int o, int l, int r, int L, int R) {
if (!o) return 0;
if (l == L && r == R) return val[o];
else {
if (R <= mid) return query(ls[o], l, mid, L, R);
else if (L > mid) return query(rs[o], mid + 1, r, L, R);
else return query(ls[o], l, mid, L, mid) + query(rs[o], mid + 1, r, mid + 1, R);
}
}
void clear() {
for (int i = 1; i <= Tlen; ++i) val[i] = ls[i] = rs[i] = 0;
Tlen = 0;
}
};
namespace T2 {
int val[MAXN * 20], ls[MAXN * 20], rs[MAXN * 20], Tlen;
#define mid ((l + r) >> 1)
void newnode(int& o) { o = ++Tlen, val[o] = ls[o] = rs[o] = 0; }
void insert(int& o, int l, int r, int pos, int k, int& aim) {
if (!o) newnode(o);
if (l == r) {
if (val[o]) T1::insert(aim, 1, N, max(val[o], k), -1), val[o] = min(val[o], k);
else val[o] = k;
}
else {
if (pos <= mid) insert(ls[o], l, mid, pos, k, aim);
else insert(rs[o], mid + 1, r, pos, k, aim);
}
}
void merge(int& o, int l, int r, int old, int& aim) {
if (!old) return;
if (!o) { o = old; return; }
if (l == r) T1::insert(aim, 1, N, max(val[o], val[old]), -1), val[o] = min(val[o], val[old]);
else merge(ls[o], l, mid, ls[old], aim), merge(rs[o], mid + 1, r, rs[old], aim);
}
int query(int& o, int l, int r, int pos) {
if (!o) return 0;
if (l == r) return val[o];
else if (pos <= mid) return query(ls[o], l, mid, pos);
else return query(rs[o], mid + 1, r, pos);
}
};
int rt1[MAXN], rt2[MAXN];
void prepare(int u) {
T1::insert(rt1[u], 1, N, dep[u], 1), T2::insert(rt2[u], 1, N, c[u], dep[u], rt1[u]);
for (int i = head[u]; i; i = E[i].next) prepare(E[i].v), T1::merge(rt1[u], 1, N, rt1[E[i].v]), T2::merge(rt2[u], 1, N, rt2[E[i].v], rt1[u]);
}
int main() {
scanf("%d", &T);
while (T--) {
scanf("%d%d", &N, &M); int i, x, d, last = 0;
for (i = 1; i <= N; ++i) scanf("%d", &c[i]);
for (i = 2; i <= N; ++i) scanf("%d", &fa[i]), add(fa[i], i);
dfs(1), prepare(1);
while (M--) {
scanf("%d%d", &x, &d), x ^= last, d ^= last;
last = T1::query(rt1[x], 1, N, dep[x], dep[x] + d);
printf("%d\n", last);
}
T1::clear(), T2::Tlen = 0;
Elen = 0;
for (i = 1; i <= N; ++i) head[i] = rt1[i] = rt2[i] = 0;
}
return 0;
}