首先放上学长博客链接
感谢宇巨抛给光巨的题,本人在抛题现场/doge
题意:
给出一个长度为n的数组,有m个询问,每次询问给出一个区间,问这个区间内有多少个数x恰好出现x次
考虑将询问离线,对每一个询问的右端点,将其左端点以及询问的id进行保存,维护其左端点,用结果 segVal(l,r)
表示当前这一段[l,r]之间的合法方案的个数
然后观察序列[2,2,2,2]
假设我们用sum[]来记录维护的左端点的贡献,用树状数组来进行操作
开始时sum[1] -> sum[4] 全为0
r | ||||
---|---|---|---|---|
1 | 0 | 0 | 0 | 0 |
2 | 1 | 0 | 0 | 0 |
3 | -1 | 1 | 0 | 0 |
4 | 0 | -1 | 1 | 0 |
在实际的维护过程中,我们只需要遍历数组,然后将这个数放进对应的multiset里面,然后进行比较a[i] 和 size的关系让年后维护上面对应的操作即可
因为数组长度为n,如果当前数组元素 > n的时候,是没有办法进行操作的,只需要对a[i] <= n
的部分进行处理
假如加入当前元素之后,size == a[i],就需要把贡献加一下(+1)
假如加入当前元素之后,size == a[i] + 1,就需要把第一个的贡献去掉,加入从第二个开始(siz == a[i])的贡献
假如加入当前元素之后,size == a[i] + 2,就需要把前两个的贡献去掉,加入从第三个开始的贡献,然后把第一个从set中删掉
一直这样维护就好啦,过程就像是上面的[2,2,2,2]
的更新过程
Code:
#define lowbit(x) (x & (-x))
#define Clear(x,val) memset(x,val,sizeof x)
int n, m;
typedef pair<int, int> PII;
ll a[maxn], sum[maxn], ans[maxn];
vector<PII> vet[maxn];
void add(int pos, int val) {
while(pos <= n) {
sum[pos] += val;
pos += lowbit(pos);
}
}
ll getSum(int pos) {
ll ret = 0;
while(pos) {
ret += sum[pos];
pos -= lowbit(pos);
}
return ret;
}
ll segVal(int l, int r) {
return getSum(r) - getSum(l - 1);
}
multiset<int> st[maxn];
int main() {
n = read, m = read;
for(int i = 1; i <= n; i++) a[i] = read;
for(int i = 1; i <= m; i++) {
int l = read, r = read;
if(l > r) swap(l, r);
vet[r].push_back({l, i});
}
// puts("ok");
for(int i = 1; i <= n; i++) {
// debug(a[i]);
if(a[i] <= n) {
st[a[i]].insert(i);
int siz = st[a[i]].size();
if(siz == a[i]) {
int pos = *st[a[i]].begin();
add(pos, 1);
} else if(siz == a[i] + 1) {
int p1 = *st[a[i]].begin();
int p2 = *next(st[a[i]].begin());
int cnt = segVal(p1, p1);/// p1 - 1 - > p1
// int cnt = getSum(p1) - getSum(p1 - 1);
add(p1, -(cnt + 1));
add(p2, 1);
} else if(siz == a[i] + 2) {
int p1 = *st[a[i]].begin();
int p2 = *next(st[a[i]].begin());
int p3 = *next(next(st[a[i]].begin()));
int cnt1 = segVal(p1, p1);
int cnt2 = segVal(p2, p2);
add(p1, -cnt1);
add(p2, -(cnt2 + 1));
add(p3, 1);
st[a[i]].erase(p1);
}
}
for(auto at : vet[i]) {
ans[at.second] = segVal(at.first,i);
}
}
for(int i = 1; i <= m; i++) {
printf("%lld\n", ans[i]);
}
return 0;
}
/**
7 2
3 1 2 2 3 3 7
1 7
3 4
**/