Description
有一个序列含有一定数量的元素,现在要求写一个程序,满足以下几个要求: 【A】支持插入操作(这个序列不允许有重复元素,即是说,如果待插入的元素已经出现在这个序列中,那么请忽略此操作) 【B】支持删除操作(如果此序列中不包含待删除的这个元素,则忽略此操作,否则删除这个元素) 【C】查找元素x的前继元素(前继元素是指:小于x且与x最接近的元素,当然,如果x已经是序列中的最小元素,则x没有前继元素) 【D】查找元素x的后继元素(后继元素是指:大于x且与x最接近的元素,当然,如果x已经是序列中的最大元素,则x没有后继元素) 【E】找第K小的元素 【F】求某个元素x的秩(即x的排名是多少,从小到大排序)
Input
多组数据(整个文件以输入 -1 结束) 对于每组数据,有若干行(最多100000行),表示的意义如下: 【A】 insert x 【B】 delete x 【C】 predecessor x 【D】 successor x 【E】 Kth x 【F】 rank x 这6种操作的意义与上面的定义相对应! 【G】 print 表示从小到大输出序列中的所有元素 【H】 end 表示结束本组数据 每组输入数据后有一空行!
Output
对于以上8种操作,分别输出对应信息,如下: 【A】 insert x 不用输出任何信息 【B】 delete x 如果x存在,则删除x,否则输出 Input Error 【C】 predecessor x 如果x不存在,输出 Input Error;否则如果x是序列中的最小元素,输出对应信息(见样例),否则输出x的前继元素 【D】 successor x 如果x不存在,输出 Input Error;否则如果x是序列中的最大元素,输出对应信息(见样例),否则输出x的后继元素 【E】 Kth x 如果x不合法,输出 Input Error;否则输出第Kth小的元素(见样例) 【F】 rank x 如果x不存在,输出 Input Error;否则输出x的排名(见样例) 【G】 print 从小到大输出序列中的所有元素,每个元素后加一个逗号,并在最后加上 end of print(见样例) 【H】 end 输出 end of this test
insert 20 insert 5 insert 1 insert 15 insert 9 insert 25 insert 23 insert 30 insert 35 print Kth 0 Kth 1 Kth 3 Kth 5 Kth 7 Kth 9 Kth 10 rank 1 rank 3 rank 5 rank 15 rank 20 rank 30 rank 31 rank 35 successor 15 successor 35 successor 25 successor 26 predecessor 1 predecessor 20 predecessor 23 predecessor 15 predecessor 111 delete 9 delete 15 delete 25 delete 23 delete 20 print Kth 3 Kth 4 rank 30 rank 35 end -1
1,5,9,15,20,23,25,30,35,end of print Input Error The 1_th element is 1 The 3_th element is 9 The 5_th element is 20 The 7_th element is 25 The 9_th element is 35 Input Error The rank of 1 is 1_th Input Error The rank of 5 is 2_th The rank of 15 is 4_th The rank of 20 is 5_th The rank of 30 is 8_th Input Error The rank of 35 is 9_th The successor of 15 is 20 35 is the maximum The successor of 25 is 30 Input Error 1 is the minimum The predecessor of 20 is 15 The predecessor of 23 is 20 The predecessor of 15 is 9 Input Error 1,5,30,35,end of print The 3_th element is 30 The 4_th element is 35 The rank of 30 is 3_th The rank of 35 is 4_th end of this test
第一次写treap
#include <iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
struct node
{
node *ch[2];//左右儿子
int r,v,size;
int cmp(int x)
{
if(x==v) return -1;
return x<v?0:1;
}
void maintain()//统计节点个数
{
size=1;
if(ch[0]!=nullptr) size+=ch[0]->size;
if(ch[1]!=nullptr) size+=ch[1]->size;
}
};
void rotate(node *&o, int d)//d=0 左旋,d=1 右旋
{
node *k = o->ch[1 ^ d];
o->ch[1 ^ d] = k->ch[d];
k->ch[d] = o;
o->maintain();
k->maintain();
o = k;
}
void insert(node*&o, int x)//插入
{
if (o == nullptr)
{
o = new node();
o->ch[0] = o->ch[1] = nullptr;
o->v = x;
o->r = rand();
}
else
{
int d = o->cmp(x);
if (d == -1) return;
insert(o->ch[d], x);
if (o->ch[d]->r > o->r)
rotate(o, 1 ^ d);
}
o->maintain();
}
void remove(node*&o, int x)//删除
{
int d = o->cmp(x);
if (d == -1)
{
if (o->ch[0] != nullptr&&o->ch[1] != nullptr)
{
int d2 = (o->ch[0]->r > o->ch[1]->r ? 1 : 0);
rotate(o, d2);
remove(o->ch[d2],x);
}
else
{
if (o->ch[0] == nullptr) o = o->ch[1];
else o = o->ch[0];
}
}
else remove(o->ch[d], x);
if(o!=nullptr) o->maintain();
}
node *find(node*o, int x)//查找
{
if (o == nullptr) return 0;
if (o->v == x) return o;
else if (o->v > x)
return find(o->ch[0], x);
else return find(o->ch[1], x);
}
void clear(node *&o)//清空
{
if (o == nullptr) return;
if (o->ch[0] != nullptr) clear(o->ch[0]);
if (o->ch[1] != nullptr) clear(o->ch[1]);
o = nullptr;
}
void print(node *o)//中序遍历
{
if (o)
{
print(o->ch[0]);
printf("%d,", o->v);
print(o->ch[1]);
}
}
int findmin(node *o)//最小值
{
while (o->ch[0])
o = o->ch[0];
return o->v;
}
int findmax(node *o)//最大值
{
while (o->ch[1])
o = o->ch[1];
return o->v;
}
node *pre(node *o, int x, node *p)//查找x的前驱元素
{
if (!o )
return p;
if (o->v < x)
return pre(o->ch[1], x, o);
else return pre(o->ch[0], x, p);
}
node *suc(node*o, int x, node*p)//查找x的后继元素
{
if (!o )
return p;
if (o->v > x)
return suc(o->ch[0], x, o);
else return suc(o->ch[1], x, p);
}
node *kth(node *o,int k)//第k小
{
if(o==nullptr) return 0;
int temp;
if(o->ch[0]==nullptr) temp=1;
else temp=o->ch[0]->size+1;
if(k==temp) return o;
else if(k<temp) return kth(o->ch[0],k);
else return kth(o->ch[1],k-temp);
}
int Rank(node *o,int val,int cur)//排名
{
if(o==nullptr) return 0;
if(val< o->v) return Rank(o->ch[0],val,cur);
if(val> o->v)
{
if(o->ch[0]==nullptr) cur+=1;
else cur+=(o->ch[0]->size+1);
return Rank(o->ch[1],val,cur);
}
if(val==o->v)
{
if(o->ch[0]==nullptr) cur+=1;
else cur+=(o->ch[0]->size+1);
return cur;
}
}
int main()
{
char s[20];
node *root=nullptr;
while(scanf("%s",s)!=EOF)
{
if(s[0]=='-') break;
if(s[0]=='e')
{
printf("end of this test\n");
clear(root);
continue;
}
if(s[0]=='p'&&s[2]=='i')
{
print(root);
printf("end of print\n");
continue;
}
int x;
scanf("%d",&x);
if(s[0]=='i')
{
insert(root,x);
continue;
}
if(s[0]=='d')
{
if (find(root, x) == 0) printf("Input Error\n");
else remove(root,x);
}
if (s[0]=='p'&&s[2]=='e')
{
if (find(root, x) == 0) printf("Input Error\n");
else
{
node *pred;
int temp = findmin(root);
if (temp == x) printf("%d is the minimum\n",x);
else
{
pred = pre(root, x, 0);
printf("The predecessor of %d is %d\n",x, pred->v);
}
}
}
if (s[0]=='s')
{
if (find(root, x) == 0) printf("Input Error\n");
else
{
node *succ;
int temp = findmax(root);
if (temp == x) printf("%d is the maximum\n", x);
else
{
succ = suc(root, x, 0);
printf("The successor of %d is %d\n",x, succ->v);
}
}
}
if(s[0]=='K')
{
if(root==nullptr||x<=0||x>root->size) printf("Input Error\n");
else
{
node *result;
result=kth(root,x);
printf("The %d_th element is %d\n",x,result->v);
}
}
if(s[0]=='r')
{
if(find(root,x)==0) printf("Input Error\n");
else
{
int ans=Rank(root,x,0);
printf("The rank of %d is %d_th\n",x,ans);
}
}
}
return 0;
}