题意:n次操作,有三种操作:加入一个集合中没有的数,删除一个集合中已有的数,求集合中位置模5余3的数之和(集合自动从小到大排序)。
因为操作最多是1e5次,但是数的范围最大到1e9,所以要采用离散化。线段树中每个结点有一个长度为5的数组,来记录模5的余数,还有一个m数组,若在第i个位置插入或删除一个数,则该位置后面的每个结点的m都要加1或者是减1,来记录该结点位置模5后余数的变化。
这道题敲了很久,离散化的时候就把自己绕晕了,之后线段树的pushup函数也让我绕了好一会儿才绕清楚。敲这个题的时候一定要思路清晰呀。。
#include <cstdio>
#include <iostream>
#include <algorithm>
#define maxn 100005
#define ls node <<1
#define rs node << 1 | 1
#define lson l, mid, ls
#define rson mid + 1, r, rs
using namespace std;
int n, cnt, val[maxn], m[maxn << 2];
char str[5];
long long sum[maxn <<2][5];
struct Operation
{
int o;
int v;
} op[maxn];
struct add_and_delete
{
int id;
int v;
int r;
} o[maxn];
bool cmp(add_and_delete x, add_and_delete y)
{
return x.v < y.v;
}
void pushup(int l, int r, int node)
{
long long t = 0;
if(l == r)
{
for(int i = 0; i < 5; i++)
if(sum[node][i])
t = sum[node][i];
for(int i = 0; i < 5; i++)
sum[node][i] = 0;
sum[node][m[node] % 5] = t;;
return;
}
for(int i = 0; i < 5; i++)
sum[node][(i + m[node]) % 5] = sum[ls][i] + sum[rs][i];
}
void build(int l, int r, int node)
{
m[node] = 0;
for(int i = 0; i < 5; i++)
sum[node][i] = 0;
if(l == r)
return;
int mid = (l + r) >> 1;
build(lson);
build(rson);
}
void update(int x, int y, int v, int l, int r, int node)
{
if(x <= l && y >= r)
{
m[node] += v;
pushup(l, r, node);
return;
}
int mid = (l + r) >>1;
if(x <= mid)
update(x, y, v, lson);
if(y > mid)
update(x, y, v, rson);
pushup(l, r, node);
}
void recover(int x, int v, int l, int r, int node)
{
if(l == r)
{
for(int i = 0; i < 5; i++)
sum[node][i] = 0;
if(v == 1)
sum[node][m[node] % 5] = val[l];
return;
}
int mid = (l + r) >>1;
if(x <= mid)
recover(x, v, lson);
if(x > mid)
recover(x, v, rson);
pushup(l, r, node);
}
int main()
{
while(scanf("%d", &n) != EOF)
{
cnt = 0;
for(int i = 0; i < n; i++)
{
scanf("%s", str);
if(str[0] == 's')
{
op[i].o = 0;
}
else if(str[0] == 'a')
{
op[i].o = 1;
scanf("%d", &op[i].v);
o[cnt].id = i;
o[cnt++].v = op[i].v;
}
else if(str[0] == 'd')
{
op[i].o = -1;
scanf("%d", &op[i].v);
o[cnt].id = i;
o[cnt++].v = op[i].v;
}
}
sort(o, o + cnt, cmp);
o[0].r = 0;
val[0] = o[0].v;
for(int i = 1; i < cnt; i++)
{
o[i].r = o[i - 1].r;
if(o[i].v != o[i - 1].v)
{
o[i].r ++;
val[o[i].r] = o[i].v;
}
}
for(int i = 0; i < cnt; i++)
op[o[i].id].v = o[i].r;
build(0, o[cnt - 1].r, 1);
for(int i = 0; i < n; i++)
{
if(op[i].o == 0)
printf("%lld\n", sum[1][3]);
else
{
update(op[i].v, o[cnt - 1].r, op[i].o, 0, o[cnt - 1].r, 1);
recover(op[i].v, op[i].o, 0, o[cnt - 1].r, 1);
}
}
}
}