题意:
给定一个 数组 a[N]
,请你统计 L
到 R
中不同数字的个数。
思路:
与之前两道题不同,本题维护的是区间线段树,而不是权值线段树。
定义:
- 定义
t[u].sum
表示u
这个节点代表的区间[l, r]
中不同数字的个数,定义t[u].l
表示u
这个节点代表的区间的左边界,t[u].r
同理。
更新线段树的办法:
- 建立
n
棵线段树,每个线段树 是以 每个位置的数 为根, 建立 第i
个线段树 时,如果 数字a[i]
之前没有出现过,那么我们 直接以位置i
为划分依据,在线段树上 包含这个位置的加1
即可,表示 这个区间上又多了一种不同的数。(用一个 标记数组st
,来标识 元素a[i]
是否出现过,为-1
代表 未出现过) - 但是如果 数字
a[i]
出现过,我们先求出 上一个同样数 在哪里出现,我们就在 第i
个线段树上 以那个 同样数出现的位置 为划分依据 减去1
,再在 第i
个线段树上 包含位置的i
的区间加1
,这样我们就保证了 区间中数字不重复,且只保留最后一个(好处:在保证了 每个数字只会保留一个位置的情况下,保留的还是最右边的,这样就 更容易落入[l, r]
这个查询区间)。
int insert(int p, int l, int r, int pos, int ok)
{
int q = ++idx;
t[q] = t[p];
if (l == r) {
t[q].sum += ok;
return q;
}
int mid = l + r >> 1;
if (pos <= mid) t[q].l = insert(t[p].l, l, mid, pos, ok);
else t[q].r = insert(t[p].r, mid + 1, r, pos, ok);
t[q].sum = t[t[q].l].sum + t[t[q].r].sum;
return q;
}
memset(st, -1, sizeof st);
for (int i = 1; i <= n; ++i)
{
scanf("%d", &a[i]);
if (st[a[i]] == -1) {
root[i] = insert(root[i - 1], 1, n, i, 1);
}
else {
int t = insert(root[i - 1], 1, n, st[a[i]], -1);
root[i] = insert(t, 1, n, i, 1);
}
st[a[i]] = i;
}
上面的代码片段 就是实现 主席树的边插入边建树 了,insert
函数 的作用是:把新的线段树 q
在继承上一个线段树 p
同时,在新线段树 q
的代表区间包含 pos
的节点上的值加上 ok
。(ok
的值 可以是 1
或 -1
)
查询答案的办法:
- 假设要 查询
[l, r]
这个区间上不同数的个数。我们可以 以l
(左边界L
,不是1
) 为划分依据,在 第r
个线段树上 进行查询。当需要 往左递归计算 时,右边的区间 一定 完整包含在我们的查询区间[l, r]
(因为我们是 在第r
棵线段树上面查询,区间[r + 1, n]
的信息 没有加入到 第r
棵线段树),我们 直接加上右区间的sum
即可。假如需要 往右区间递归,那么直接递归就好了,左边的sum
不用加上,因为 左边区间中[1, l + 1]
在第r
棵线段树中。 - 这样处理,我们就可以得到:完整区间
[l, r]
上不同数的个数。
int ask(int q, int l, int r, int pos)
{
if (l == r) return t[q].sum;
int mid = l + r >> 1;
if (pos <= mid) return ask(t[q].l, l, mid, pos) + t[t[q].r].sum;
else return ask(t[q].r, mid + 1, r, pos);
}
代码:
#include <bits/stdc++.h>
using namespace std;
//#define map unordered_map
//#define int long long
const int N = 3e4 + 10, M = 1e6 + 10, Q = 2e5 + 10;
const int ALL = (N << 2) + Q * 15;
int st[M];
int a[N], n, m;
int root[N];
struct node
{
int l, r;
int sum;
} t[ALL];
int idx;
int build(int l, int r)
{
int q = ++idx;
if (l == r) return q;
int mid = l + r >> 1;
t[q].l = build(l, mid), t[q].r = build(mid + 1, r);
return q;
}
int insert(int p, int l, int r, int pos, int ok)
{
int q = ++idx;
t[q] = t[p];
if (l == r) {
t[q].sum += ok;
return q;
}
int mid = l + r >> 1;
if (pos <= mid) t[q].l = insert(t[p].l, l, mid, pos, ok);
else t[q].r = insert(t[p].r, mid + 1, r, pos, ok);
t[q].sum = t[t[q].l].sum + t[t[q].r].sum;
return q;
}
int ask(int q, int l, int r, int pos)
{
if (l == r) return t[q].sum;
int mid = l + r >> 1;
if (pos <= mid) return ask(t[q].l, l, mid, pos) + t[t[q].r].sum;
else return ask(t[q].r, mid + 1, r, pos);
}
signed main()
{
cin >> n;
root[0] = build(1, n);
memset(st, -1, sizeof st);
for (int i = 1; i <= n; ++i)
{
scanf("%d", &a[i]);
if (st[a[i]] == -1) {
root[i] = insert(root[i - 1], 1, n, i, 1);
}
else {
int t = insert(root[i - 1], 1, n, st[a[i]], -1);
root[i] = insert(t, 1, n, i, 1);
}
st[a[i]] = i;
}
cin >> m;
while (m--)
{
int l, r; scanf("%d%d", &l, &r);
printf("%d\n", ask(root[r], 1, n, l));
}
return 0;
}