题目描述
要你维护一个集合,集合里面的元素为二元组(x,y)。这个集合中一开始有n个元素。
支持如下操作:
- Insert,x0,y0, 插入一个元素(x0,y0)。
- Delete,x0 ,删除所有x值为x0的元素。
- Querymin1,查询所有元素中x值最小的元素。
- Querymin2,查询所有元素中y值最小的元素。
- Query1,x0,查询集合中是否有x值为x0的元素。
- Query2,x0,查询集合中x值大于等于x0的第一个元素(不存在的话输出"-1")。
- Query3,x0,查询x值小于等于x0的所有元素中y值的最小值(不存在的话输出"-1")。
- Query4,k,查询集合中x值第k大的元素。
- Query5,k,查询x值从小到大前k个元素的y值的最大值。
- Query6,k,查询x值从小到大前k个元素的y值的和。
保证输入的x中没有重复。
数据范围:n<=1e5,操作个数Q<=1e5。
输入格式
第一行输入两个数n,Q,后面n行每行输入两个整数x,y。
后面Q行每行进行一个操作。
输出格式
输出操作3-10中的查询结果。
这道题教会了我如果用平衡树来维护一些简单的辅助信息,极值、区间和等等(这道题我调了几乎一整天)。
代码如下
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <stack>
#include <map>
#include <queue>
#include <vector>
#include <set>
#include <algorithm>
#include <iomanip>
#define LL long long
#define PII pair<int,int>
#define L tr[u].l
#define R tr[u].r
using namespace std;
const int N=2e5+5,INF=1e9;
struct Node{
int l,r;
int x,y,val; //x为key,y为辅助信息
int size,max,min=1e9; //max、min记录y的最大值、最小值
LL sum; //记录以该节点为根的子树中y的和
}tr[N];
int root,idx;
void pushup(int u) //由左右儿子节点推出父亲节点的信息
{
tr[u].size=tr[L].size+tr[R].size+1;
tr[u].sum=tr[L].sum+tr[R].sum+tr[u].y;
tr[u].max=max(tr[u].y,max(tr[L].max,tr[R].max));
tr[u].min=min(tr[u].y,min(tr[L].min,tr[R].min));
}
int getNode(int key,int y) //创建新节点
{
tr[++idx].x=key;
tr[idx].y=y;
tr[idx].max=tr[idx].min=tr[idx].sum=y;
tr[idx].size=1;
tr[idx].val=rand();
return idx;
}
void zig(int &u) //右旋
{
int q=L;
L=tr[q].r,tr[q].r=u;
u=q;
pushup(R);
pushup(u);
}
void zag(int &u) //左旋
{
int q=R;
R=tr[q].l,tr[q].l=u;
u=q;
pushup(L);
pushup(u);
}
void insert(int &u,int key,int y) //插入一个节点(操作1)
{
if(!u) u=getNode(key,y);
else if(key<tr[u].x)
{
insert(L,key,y);
if(tr[L].val>tr[u].val) zig(u);
}
else
{
insert(R,key,y);
if(tr[R].val>tr[u].val) zag(u);
}
pushup(u);
}
void remove(int &u,int key) //删除一个节点(操作2)
{
if(!u) return;
if(tr[u].x==key)
{
if(L||R)
{
if(!R||tr[L].val>tr[R].val)
{
zig(u);
remove(R,key);
}
else
{
zag(u);
remove(L,key);
}
}
else u=0;
}
else if(key<tr[u].x) remove(L,key);
else remove(R,key);
if(u) pushup(u);
}
Node get(int &u,int key) //判断key节点十分存在(操作5)
{
if(!u) return Node();
if(tr[u].x==key) return tr[u];
if(tr[u].x>key) return get(L,key);
return get(R,key);
}
int find(int &u,int rank) //找出树中排名为rank的节点(操作4和操作8)
{
if(!u) return -1;
if(tr[L].size>=rank) return find(L,rank);
if(tr[L].size+1>=rank) return tr[u].x;
return find(R,rank-tr[L].size-1);
}
int getPrev(int u,int key) //找出平衡树中小于等于key的最大值(操作7)
{
if(!u) return -1;
if(key<tr[u].x) return getPrev(L,key);
return max(tr[u].x,getPrev(R,key));
}
int getNext(int u,int key) //找出平衡树中大于等于key的最小值(操作6)
{
if(!u) return INF;
if(key>tr[u].x) return getNext(R,key);
return min(tr[u].x,getNext(L,key));
}
int findMax(int &u,int rank) //找出树中x排名为1-rank的y的最大值(操作9)
{
if(!u) return -INF;
if(tr[L].size>=rank) return findMax(L,rank);
if(tr[L].size+1>=rank) return max(tr[L].max,tr[u].y);
return max(tr[L].max,max(tr[u].y,findMax(R,rank-1-tr[L].size)));
}
int getMin(int &u,int key) //找出树中x小于等于key的所有数的最小值(操作7)
{
if(!u) return INF;
if(key<tr[u].x) return getMin(L,key);
if(key==tr[u].x) return min(tr[L].min,tr[u].y);
return min(tr[L].min,min(tr[u].y,getMin(R,key)));
}
LL findSum(int &u,int rank) //找出树中x排名为1-rank的y的和(操作10)
{
if(!u) return 0;
if(tr[L].size>=rank) return findSum(L,rank);
if(tr[L].size+1>=rank) return tr[L].sum+tr[u].y;
return tr[L].sum+tr[u].y+findSum(R,rank-1-tr[L].size);
}
int main()
{
int n,m;
scanf("%d %d",&n,&m);
for(int i=1;i<=n;i++)
{
int x,y;
scanf("%d %d",&x,&y);
insert(root,x,y);
}
while(m--)
{
int op,x,y;
scanf("%d",&op);
if(op==1) //插入
{
n++;
scanf("%d %d",&x,&y);
insert(root,x,y);
}
else if(op==2) //删除
{
n--;
scanf("%d",&x);
remove(root,x);
}
else if(op==3) //x的最小值
{
printf("%d\n",find(root,1));
}
else if(op==4) //y的最小值
{
printf("%d\n",tr[root].min);
}
else if(op==5) //x是否存在
{
scanf("%d",&x);
Node t=get(root,x);
if(t.x) puts("YES");
else puts("NO");
}
else if(op==6) //x的后缀
{
scanf("%d",&x);
int t=getNext(root,x);
if(t>=INF) t=-1;
printf("%d\n",t);
}
else if(op==7) //小于等于x的数中y的最小值
{
scanf("%d",&x);
int key=getPrev(root,x);
printf("%d\n",getMin(root,key));
}
else if(op==8) //x的第k大元素
{
scanf("%d",&x);
printf("%d\n",find(root,n-x+1));
}
else if(op==9) //前k个x中y的最大值
{
scanf("%d",&x);
printf("%d\n",findMax(root,x));
}
else //前k个x中y的和
{
scanf("%d",&x);
printf("%lld\n",findSum(root,x));
}
}
return 0;
}