/*
 * A simple balanced-binary-tree exercise supporting three operations: insert,
 * query the extreme (maximum or minimum) element, and delete.
 *
 * Author's note: the key struct uses operator overloading purely for coding
 * convenience; without the overloads it would probably run a bit faster.
 *
 * The tree is stored in a statically allocated node pool and rebalanced by
 * subtree size (the routines follow the shape of a Size Balanced Tree).
 * Index 0 is the "null" sentinel: it is never allocated and its size stays 0.
 */
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
using namespace std;

/* Capacity of the node pool. Slots are never recycled, so this bounds the
 * total number of insertions over the lifetime of the program. */
const int MAX = 1100000;

/* Key stored in the tree: ordered by `pro` first, then by `id`.
 * (`pro` is presumably the priority and `id` the client id -- confirm against
 * the problem statement.) */
struct In {
    int id;
    int pro;

    /* Leaves the fields unset; only needed so that `Node` arrays can be declared. */
    In() {}
    In(int id, int pro) : id(id), pro(pro) {}

    bool operator==(const In& in) const {
        return id == in.id && pro == in.pro;
    }

    bool operator<(const In& in) const {
        if (pro != in.pro) {
            return pro < in.pro;
        }
        return id < in.id;
    }

    bool operator<=(const In& in) const {
        return *this < in || *this == in;
    }

    bool operator>(const In& in) const {
        return !(*this <= in);
    }

    bool operator>=(const In& in) const {
        return !(*this < in);
    }
};

/* One tree node. `left`/`right` are indices into `node` (0 means "no child"),
 * `size` is the number of nodes in the subtree rooted here. `cnt` is not used
 * by any routine in this file. */
struct Node {
    int left, right, size, cnt;
    In key;

    /* Resets the node to a freshly inserted leaf. */
    void init() {
        left = right = 0;
        size = 1;
    }
} node[MAX];

int tol;   /* number of pool slots handed out so far (never decreases) */
int root;  /* index of the tree root, 0 when the tree is empty */

/* Resets the tree to empty and starts allocating from slot 1 again. */
void init() {
    tol = root = 0;
}

/* Left rotation: the right child of `t` becomes the new subtree root.
 * `t` is updated in place to the new root index. */
void Lt(int &t) {
    int k = node[t].right;
    node[t].right = node[k].left;
    node[k].left = t;
    node[k].size = node[t].size;
    node[t].size = node[node[t].left].size + node[node[t].right].size + 1;
    t = k;
    return;
}

/* Right rotation: the left child of `t` becomes the new subtree root.
 * `t` is updated in place to the new root index. */
void Rt(int &t) {
    int k = node[t].left;
    node[t].left = node[k].right;
    node[k].right = t;
    node[k].size = node[t].size;
    node[t].size = node[node[t].left].size + node[node[t].right].size + 1;
    t = k;
    return;
}

/* Restores the size-balance property at `t`.
 *
 * flag == false: checks whether a grandchild under the LEFT child is larger
 *                than the whole right subtree.
 * flag == true : the mirror case for the RIGHT child.
 *
 * When a rotation is performed, the affected subtrees are re-checked
 * recursively; otherwise the function returns immediately. */
void keep(int &t, bool flag) {
    if (flag == 0) {
        if (node[node[node[t].left].left].size > node[node[t].right].size)
            Rt(t);
        else if (node[node[node[t].left].right].size > node[node[t].right].size) {
            Lt(node[t].left);
            Rt(t);
        } else
            return;
    } else {
        if (node[node[node[t].right].right].size > node[node[t].left].size)
            Lt(t);
        else if (node[node[node[t].right].left].size > node[node[t].left].size) {
            Rt(node[t].right);
            Lt(t);
        } else
            return;
    }
    keep(node[t].left, 0);
    keep(node[t].right, 1);
    keep(t, 0);
    keep(t, 1);
    return;
}

/* Inserts key `v` into the subtree rooted at `t` (updated in place).
 * Equal keys go to the right. The caller must ensure fewer than MAX - 1
 * nodes have been allocated. */
void insert(int &t, const In& v) {
    if (t == 0) {
        t = ++tol;
        node[t].init();
        node[t].key = v;
    } else {
        node[t].size++;
        if (v < node[t].key)
            insert(node[t].left, v);
        else
            insert(node[t].right, v);
        /* Rebalance on the side the new key went to. */
        keep(t, v >= node[t].key);
    }
    return;
}

/* Removes key `v` from the subtree rooted at `t` (updated in place) and
 * returns the index of the node slot that was unlinked, or 0 if the subtree
 * is empty.
 *
 * If `v` is not present, the last node on its search path is removed instead,
 * so callers should only pass keys that are known to be in the tree.
 * Subtree sizes are decremented along the search path. */
int del(int &t, const In& v) {
    if (!t) return 0;
    node[t].size--;
    if (v == node[t].key ||
        (v < node[t].key && !node[t].left) ||
        (v > node[t].key && !node[t].right)) {
        if (node[t].left && node[t].right) {
            /* Two children: search the left subtree for a key just above `v`.
             * Every key there is smaller, so this walks to and unlinks the
             * in-order predecessor, whose key then replaces this node's key. */
            int p = del(node[t].left, In(v.id + 1, v.pro));
            node[t].key = node[p].key;
            return p;
        } else {
            /* At most one child: splice it into this node's place
             * (the missing child index is 0, so the sum is the other one). */
            int p = t;
            t = node[t].left + node[t].right;
            return p;
        }
    } else {
        return del(v < node[t].key ? node[t].left : node[t].right, v);
    }
}

/* Returns the k-th smallest key (1-based) in the subtree rooted at `t`.
 * If `k` is outside 1..size, the sentinel's key `node[0].key` is returned. */
In select(int t, int k) {
    while (t != 0) {
        const int leftSize = node[node[t].left].size;
        if (k <= leftSize) {
            t = node[t].left;
        } else if (k > leftSize + 1) {
            k -= leftSize + 1;
            t = node[t].right;
        } else {
            return node[t].key;
        }
    }
    /* Fell off the tree: k was out of range. */
    return node[0].key;
}

/* Returns the index of the node holding the largest key under `t`
 * (returns `t` itself, i.e. 0, for an empty tree). */
int getmax(int t) {
    while (node[t].right) t = node[t].right;
    return t;
}

/* Returns the index of the node holding the smallest key under `t`
 * (returns `t` itself, i.e. 0, for an empty tree). */
int getmin(int t) {
    while (node[t].left) t = node[t].left;
    return t;
}

/* Command loop over standard input:
 *   0        stop
 *   1 id pro insert (id, pro)
 *   2        print and remove the id of the largest key, or 0 if empty
 *   other    print and remove the id of the smallest key, or 0 if empty
 * Reading also stops at end of input or on a malformed number. */
int main() {
    int cmd;
    init();
    while (scanf("%d", &cmd) == 1 && cmd != 0) {
        if (cmd == 1) {
            int id, pro;
            if (scanf("%d%d", &id, &pro) != 2) {
                break;
            }
            insert(root, In(id, pro));
        } else if (root == 0) {
            /* `tol` never shrinks, so emptiness has to be judged by `root`. */
            puts("0");
        } else {
            const int idx = (cmd == 2) ? getmax(root) : getmin(root);
            printf("%d\n", node[idx].key.id);
            del(root, node[idx].key);
        }
    }
    return 0;
}