题目链接:https://www.luogu.org/problem/P1972
题目描述
HH 有一串由各种漂亮的贝壳组成的项链。HH 相信不同的贝壳会带来好运,所以每次散步完后,他都会随意取出一段贝壳,思考它们所表达的含义。HH 不断地收集新的贝壳,因此,他的项链变得越来越长。有一天,他突然提出了一个问题:某一段贝壳中,包含了多少种不同的贝壳?这个问题很难回答……因为项链实在是太长了。于是,他只好求助睿智的你,来解决这个问题。
输入格式
第一行:一个整数N,表示项链的长度。
第二行:N 个整数,表示依次表示项链中贝壳的编号(编号为0 到1000000 之间的整数)。
第三行:一个整数M,表示HH 询问的个数。
接下来M 行:每行两个整数,L 和R(1 ≤ L ≤ R ≤ N),表示询问的区间。
输出格式
M 行,每行一个整数,依次表示询问对应的答案。
输入输出样例
输入 #1复制
6
1 2 3 4 3 5
3
1 2
3 5
2 6
输出 #1复制
2
2
4
说明/提示
对于20%的数据,n,m≤5000
对于40%的数据,n,m≤10^5
对于60%的数据,n,m≤5×10^5
对于所有数据,n,m≤1×10^6
本题可能需要较快的读入方式,最大数据点读入数据约20MB
考虑前缀和。
用线段树或树状数组维护不同贝壳数的前缀和,那么我们可以轻松查询出在区间上的不同贝壳数。
但是如何把前缀和转化为区间和呢?
假设区间为,我们能轻易求出
,但是由于统计
时,对于
段与之前重复的贝壳不统计,因此不能直接用
作为答案。
比方说对于项链1 1 1 1 1,;区间
显然有一种不同的贝壳,但并非
种。
所以,我们的思路是:补偿因为重复而造成的统计上的损失。
那么只需从左到右扫一遍,设某个询问的区间左端点为,由于
右边的区间均需要补偿,所以对于
区间内的所有数,都要对在它后面一个的相等的数进行add操作。
具体来说,就是对询问进行离线处理,将询问按区间左端点进行排序;同时预处理数组,
表示与第
个贝壳相同的下一个贝壳。这是本题最核心的部分,代码大致就是这样:
for (int i=1;i<=m;i++)
{
while (loc<query[i].left)
{
if (nxt[loc]!=n+1)
add(nxt[loc],1);
loc++;
}
ans[query[i].label]=sum(query[i].right)-sum(query[i].left-1);
}
其中表示当前扫到的位置,可以看到,通过对
进行补偿,原本没有被统计到的重复的贝壳被统计了;另外,由于补偿影响前缀和,因此会让后面所有的前缀和都+1,即使补偿的贝壳类型不再出现,也会因为前缀和相减而抵消,不会对后续询问造成影响。
这题还是值得消化一下的,如果看不懂的话,可以先意会一下,调一调代码,就明白为什么了。
完整代码:
#include<cstdio>
#include<algorithm>
using namespace std;
struct newdata
{
int left,right,label;
};
int n,m;
bool occur[1000001];
int a[1000001];
int tree[1000001];
int nxt[1000001];
int now[1000001];
int ans[1000001];
newdata query[1000001];
bool cmp(newdata i,newdata j)
{
return i.left<j.left;
}
int lowbit(int x)
{
return x&(-x);
}
void add(int i,int x)
{
while (i<=n)
{
tree[i]+=x;
i+=lowbit(i);
}
return;
}
int sum(int i)
{
int x=0;
while (i>0)
{
x+=tree[i];
i-=lowbit(i);
}
return x;
}
int main()
{
scanf("%d",&n);
for (int i=1;i<=n;i++) //预处理前缀和
{
int x;
scanf("%d",&x);
a[i]=x;
if (!occur[x])
{
occur[x]=true;
add(i,1);
}
}
scanf("%d",&m);
for (int i=1;i<=m;i++)
{
scanf("%d%d",&query[i].left,&query[i].right);
query[i].label=i;
}
sort(query+1,query+m+1,cmp); //预处理询问
for (int i=n;i>=1;i--) //预处理nxt数组
if (!now[a[i]])
{
now[a[i]]=i;
nxt[i]=n+1;
}
else
{
nxt[i]=now[a[i]];
now[a[i]]=i;
}
int loc=1;
for (int i=1;i<=m;i++) //核心代码
{
while (loc<query[i].left)
{
if (nxt[loc]!=n+1)
add(nxt[loc],1);
loc++;
}
ans[query[i].label]=sum(query[i].right)-sum(query[i].left-1);
}
for (int i=1;i<=m;i++)
printf("%d\n",ans[i]);
return 0;
}