Set.h
#include <algorithm>
#include <iostream>
#include <stack>
#include <utility>
template <typename Comparable>
class Set
{
private:
struct BinaryNode //结点定义
{
Comparable element;
BinaryNode* left, * right;
BinaryNode* prev, * next, * parent;//增加前驱后继及双亲结点
int height;
explicit BinaryNode(const Comparable& el = Comparable(),
BinaryNode* le = nullptr, BinaryNode* ri = nullptr,
BinaryNode* pr = nullptr, BinaryNode* ne = nullptr, BinaryNode* pa = nullptr, int h = 0)
:element(el), left(le), right(ri), prev(pr), next(ne), parent(pa), height(h) {};
};
//私有数据成员
BinaryNode* root; //根结点
BinaryNode* smallest; //头结点,(avl树不空的情况下其后继结点为最小元素所在结点)
BinaryNode* largest; //尾结点, (avl树不空的情况下其前驱结点为最大元素所在结点)
int theSize;//记录元素个数
void init()//初始化Set数据
{
root = nullptr;
smallest = new BinaryNode(Comparable{});
largest = new BinaryNode(Comparable{});
theSize = 0;
smallest->next = largest;
largest->prev = smallest;
}
public:
//构造、复制构造、赋值和析构函数(常见6种)
Set()
{
init();
}
Set(const Set& rhs)
{
init();
clone(rhs.root);
theSize = rhs.theSize;
}
Set(Set&& rhs) :root(rhs.root), smallest(rhs.smallest), largest(rhs.largest), theSize(rhs.theSize)
{
rhs.root = nullptr;
rhs.smallest = nullptr;
rhs.largest = nullptr;
rhs.theSize = 0;
}
Set& operator = (const Set& rhs)
{
Set copy = rhs;
std::swap(*this, copy);
return *this;
}
Set& operator = (Set&& rhs)
{
std::swap(root, rhs.root);
std::swap(largest, rhs.largest);
std::swap(smallest, rhs.smallest);
std::swap(theSize, rhs.theSize);
return *this;
}
~Set()
{
makeEmpty();
delete largest;
delete smallest;
}
//迭代器类
class const_iterator //const迭代器, 数据成员为指向结点的指针
{
protected:
BinaryNode* current;
const_iterator(BinaryNode* p) : current(p) {}
Comparable& retrieve() const
{
return current->element;
}
friend class Set;
public:
const_iterator() : current(nullptr) {}
const Comparable& operator* () const
{
return retrieve();
}
const_iterator& operator ++ ()
{
current = current->next;
return *this;
}
const_iterator operator ++ (int)
{
const_iterator old = *this;
++(*this);
return old;
}
const_iterator& operator -- ()
{
current = current->prev;
return *this;
}
const_iterator operator -- (int)
{
const_iterator old = *this;
--(*this);
return old;
}
bool operator == (const const_iterator& rhs) const
{
return current == rhs.current;
}
bool operator != (const const_iterator& rhs) const
{
return !(*this == rhs);
}
};
class iterator : public const_iterator //迭代器,是const迭代器的公有继承
{
protected:
iterator(BinaryNode* p) : const_iterator(p) {}
friend class Set;
public:
iterator() {};
Comparable& operator *()
{
return const_iterator::retrieve();
}
const Comparable& operator *() const
{
return const_iterator::operator*();
}
iterator& operator ++ ()
{
this->current = this->current->next;
return *this;
}
iterator operator ++ (int)
{
iterator old = *this;
++(*this);
return old;
}
iterator& operator --()
{
this->current = this->current->prev;
return *this;
}
iterator operator -- (int)
{
iterator old = *this;
--(*this);
return old;
}
};
//begin和end函数
const_iterator begin() const
{
return smallest->next;
}
iterator begin()
{
return smallest->next;
}
const_iterator end() const
{
return largest;
}
iterator end()
{
return largest;
}
//公有接口(函数),部分需要调用私有函数
const Comparable& findMin() const
{
return *findMin(root);
}
const Comparable& findMax() const
{
return *findMax(root);
}
bool contains(const Comparable& x) const
{
return contains(x, root);
}
bool isEmpty() const
{
return root == nullptr;
}
void printTree(ostream& out = cout) const
{
printTree(root, out);
}
void makeEmpty()
{
makeEmpty(smallest->next);
smallest->next = largest;
largest->prev = smallest;
}
int size() const
{
return theSize;
}
void clone(BinaryNode* t)
{
BinaryNode* pre = smallest;//设置前驱结点,便于中序遍历复制时建立前驱和后继链接。
root = clone(t, &pre);
pre->next = largest;// 调用私有clone后pre指向最大元素所在结点。
largest->prev = pre;
}
//重难点在于插入和删除结点链接的重新调整,以及对不平衡子树高度的调整,特别要注意一些细节处理。
//(1)插入
pair<iterator, bool> insert(const Comparable& x)//调用私有insert
{
bool bo = false;//bo用于返回值,如果插入成功则为true,否则为false
pair<iterator, bool> ret = std::make_pair(insert(x, root, smallest, largest, bo), bo);
return ret;
}
iterator insert(const_iterator iter, const Comparable& x)//iter给出插入位置信息,如果信息无效则调用上面的insert函数。
{
BinaryNode* cur = iter.current, * after_cur = cur->next, * before_cur = cur->prev,
* before, * after;
if ((before_cur == smallest || x > before_cur->element) && (cur == largest || cur->element > x))//插入位置介于--iter和iter之间
{
before = before_cur;
after = cur;
}
else if (after_cur != nullptr && x > cur->element && (after_cur == largest || x < after_cur->element))//插入位置介于iter和++iter之间
{
before = cur;
after = after_cur;
}
else//否则调用上面的insert函数
{
return insert(x).first;
}
//调整插入结点的前驱和后继链接
BinaryNode* add = new BinaryNode(x, nullptr, nullptr, before, after), * adjust;//adjust指向需要调整的结点
before->next = add;
after->prev = add;
//调整插入结点的双亲链接及双亲结点的左右孩子链接
if (before == smallest && after != largest)//如果插入在头结点的后面
{
after->left = add;
add->parent = after;
adjust = after;
}
else// 否则找到before结点右子树最左下结点,将插入结点作为右子树最左下结点的左孩子结点。
{
BinaryNode* rb = before->right;
if (rb == nullptr)
{
before->right = add;
add->parent = before;
adjust = before;
}
else
{
while (rb->left != nullptr)
{
rb = rb->left;
}
rb->left = add;
add->parent = rb;
adjust = rb;
}
}
//插入后对插入结点的双亲结点调整,循环往上,知道结点的高度和插入前相同停止调整
//!!!注意:插入和删除结点调整都采用以下形式。
while (adjust != nullptr)
{
int original_h = adjust->height;
BinaryNode* new_adj = balance(adjust);
if (original_h == new_adj->height)
{
break;
}
adjust = new_adj->parent;
}
theSize++;
return add;
}
template <class InputIterator>
void insert(InputIterator first, InputIterator last)
{
for (auto iter = first; iter != last; iter++)
{
insert(*iter);
}
}
//(2)删除
int remove(const Comparable& x)
{
int ret = 0; //返回删除元素个数,如果未删除为0,进行了删除操作则为1
remove(x, root, ret);
return ret;
}
iterator remove(const_iterator iter)
{
BinaryNode* cur = iter.current, * cur_next = cur->next, * cur_prev = cur->prev, * adjust;
iterator ret = iterator(cur_next);
//对删除结点前驱和后继结点的前驱和后继链接进行调整
cur_next->prev = cur_prev;
cur_prev->next = cur_next;
//对删除结点的双亲结点和双亲结点的左右孩子结点进行调整
//删除结点的左右子树不为空,则用其后继结点(右子树最左下结点)代替删除结点。
//注意不能采用这种方法:将删除结点赋值为其后继结点的值,再删除其后继结点。因为这样会导致指向其后继结点的迭代器失效,使用下一个remove函数时
//可以发现这个问题。
if (cur->left != nullptr && cur->right != nullptr)
{
BinaryNode* cnr = cur_next->right, * cnp = cur_next->parent, * cur_left = cur->left,
* cur_right = cur->right, * cur_par = cur->parent;
//用删除结点的后继结点替代删除结点,并调整后继结点的双亲和左右孩子结点。
cur_next->left = cur_left, cur_next->right = cur_right, cur_next->parent = cur_par;
cur_left->parent = cur_next;
cur_right->parent = cur_next;
if (cur_par == nullptr)
{
root = cur_next;
}
else
{
if (cur_par->left == cur)
{
cur_par->left = cur_next;
}
else
{
cur_par->right = cur_next;
}
}
cur_next->height = cur->height;
//对后继结点的双亲结点和其右孩子结点的双亲和左右孩子链接进行调整
if (cnp == cur)
{
cur_next->right = cnr;
if (cnr != nullptr)
{
cnr->parent = cur_next;
}
adjust = cur_next;
}
else
{
cnp->left = cnr;
if (cnr != nullptr)
{
cnr->parent = cnp;
}
adjust = cnp;
}
}
else//左右子树有一个为空,则用其左子树或右子树替代删除结点。
{
BinaryNode* cur_child = (cur->left != nullptr ? cur->left : cur->right);
BinaryNode* cur_par = cur->parent;
if (cur_child != nullptr)
{
cur_child->parent = cur_par;
}
if (cur_par == nullptr)
{
root = cur_child;
}
else
{
if (cur_par->left == cur)
{
cur_par->left = cur_child;
}
else
{
cur_par->right = cur_child;
}
}
adjust = cur_par;
}
delete cur;
theSize--;
while (adjust != nullptr)//删除结点后调整高度,采用前述形式。
{
int or_h = adjust->height;
BinaryNode* new_adj = balance(adjust);
if (or_h == new_adj->height)
{
break;
}
adjust = new_adj->parent;
}
return ret;
}
iterator remove(const_iterator first, const_iterator last)
{
const_iterator it = first;
while (it != last)
{
it = remove(it);
}
return it.current;
}
private:
//私有方法
static const int ALLOWED_IMBALANCE = 1; //常量,用于判断左右子树是否平衡。
BinaryNode* clone(BinaryNode* t, BinaryNode** pre)//中序遍历复制结点。
{
if (t == nullptr)
{
return nullptr;
}
stack<BinaryNode*> st1, st2;
BinaryNode* p1 = t, * p2 = new BinaryNode(t->element), * newroot = p2;
while (p1 != nullptr || !st1.empty())
{
if (p1 != nullptr)
{
st1.push(p1);
if (p1->left != nullptr)
{
p2->left = new BinaryNode(p1->left->element, nullptr, nullptr, nullptr, nullptr, p2);
}
st2.push(p2);
p1 = p1->left;
p2 = p2->left;
}
else
{
p1 = st1.top(); st1.pop(); p2 = st2.top(); st2.pop();
(*pre)->next = p2;
p2->prev = *pre;
*pre = p2;
if (p1->right != nullptr)
{
p2->right = new BinaryNode(p1->right->element, nullptr, nullptr, nullptr, nullptr, p2);
}
p1 = p1->right;
p2 = p2->right;
}
}
return newroot;
}
int height(BinaryNode* t) const //求结点高度,为nullptr时高度为-1
{
return t == nullptr ? -1 : t->height;
}
const_iterator findMin(BinaryNode* t) const
{
if (t == nullptr)
{
return nullptr;
}
while (t->left != nullptr)
{
t = t->left;
}
return t;
}
const_iterator findMax(BinaryNode* t) const
{
if (t == nullptr)
{
return nullptr;
}
while (t->right != nullptr)
{
t = t->right;
}
return t;
}
bool contains(const Comparable& x, BinaryNode* t) const
{
if (t == nullptr)
{
return false;
}
while (t != nullptr && t->element != x)
{
if (x < t->element)
{
t = t->left;
}
else
{
t = t->right;
}
}
return t != nullptr && t->element == x;
}
void printTree(BinaryNode* t, ostream& out) const//Morris中序遍历,有疑问可以参考leetcode二叉树搜索树中序遍历。
{
BinaryNode* p;
while (t != nullptr)
{
if (t->left != nullptr)
{
p = t->left;
while (p->right != nullptr && p->right != t)
{
p = p->right;
}
if (p->right == t)
{
out << t->element << endl;
p->right = nullptr;
t = t->right;
}
else
{
p->right = t;
t = t->left;
}
}
else
{
out << t->element << endl;
t = t->right;
}
}
}
void makeEmpty(BinaryNode* t)
{
BinaryNode* temp;
if (t != largest)
{
temp = t;
t = t->next;
delete temp;
}
}
iterator insert(const Comparable& x, BinaryNode*& t, BinaryNode* small, BinaryNode* large, bool& bo)
{
BinaryNode* ret = t;
if (t == nullptr)//根结点为nullptr时
{
t = new BinaryNode(x, nullptr, nullptr, small, large);
small->next = t;
large->prev = t;
theSize++;
bo = true;
return t;
}
BinaryNode* ct = t;
//找插入结点的前驱后继small、large及双亲结点ct。
while (x < ct->element && ct->left != nullptr || x > ct->element && ct->right != nullptr)
{
if (x < ct->element && ct->left != nullptr)
{
large = ct;
ct = ct->left;
}
else
{
small = ct;
ct = ct->right;
}
}
if (x != ct->element)
{
bo = true;
BinaryNode* add;
if (x < ct->element)
{
large = ct;
add = ct->left = new BinaryNode(x, nullptr, nullptr, small, large, ct);
}
else
{
small = ct;
add = ct->right = new BinaryNode(x, nullptr, nullptr, small, large, ct);
}
small->next = add;
large->prev = add;
theSize++;
ret = add;
while (ct != nullptr)//高度调整
{
int original_h = ct->height;
BinaryNode* new_ct = balance(ct);
if (original_h == new_ct->height)
{
break;
}
ct = new_ct->parent;
}
}
else
{
bo = false;
}
return ret;
}
void remove(const Comparable& x, BinaryNode*& t, int& ret)
{
if (t == nullptr)//根结点为nullptr,删除失败
{
return;
}
//找到删除结点
BinaryNode* ct = t;
while (x < ct->element && ct->left != nullptr || x > ct->element && ct->right != nullptr)
{
if (x < ct->element && ct->left != nullptr)
{
ct = ct->left;
}
else
{
ct = ct->right;
}
}
if (ct->element == x)
{
//删除结点的左右子树不为空,将其后继结点的值赋值到删除结点,并转换为删除后继结点。
//这里也可以直接调用迭代器版本删除,这里提供不同的方法供参考。
if (ct->left != nullptr && ct->right != nullptr)
{
ct->element = *findMin(ct->right);
remove(ct->element, ct->right, ret);
}
else//左右子树有一个为空,则用其左子树或右子树替代。
{
ret = 1;
theSize--;
BinaryNode* ct_par = ct->parent, * ct_prev = ct->prev, * ct_next = ct->next,
* ct_child = (ct->left != nullptr ? ct->left : ct->right);
//前驱后继结点调整
ct_prev->next = ct_next;
ct_next->prev = ct_prev;
//双亲和左右孩子结点调整
if (ct_child != nullptr)
{
ct_child->parent = ct_par;
}
if (ct_par == nullptr)
{
t = ct_child;
}
else
{
if (x < ct_par->element)
{
ct_par->left = ct_child;
}
else if (x >= ct_par->element)
{
ct_par->right = ct_child;
}
}
delete ct;
ct = ct_par;
while (ct != nullptr)//调整高度
{
int original_h = ct->height;
BinaryNode* new_ct = balance(ct);
if (original_h == new_ct->height)
{
break;
}
ct = new_ct->parent;
}
}
}
else
{
return;
}
}
BinaryNode* balance(BinaryNode* t)//高度调整,结点不平衡时根据四种情况调整
{
BinaryNode* ret = t;
if (t == nullptr)
{
return nullptr;
}
else if (height(t->left) - height(t->right) > ALLOWED_IMBALANCE)
{
if (height(t->left->left) >= height(t->left->right))
{
ret = rotateWithLeftChild(t);
}
else
{
ret = doubleWithLeftChild(t);
}
}
else if (height(t->right) - height(t->left) > ALLOWED_IMBALANCE)
{
if (height(t->right->right) >= height(t->right->left))
{
ret = rotateWithRightChild(t);
}
else
{
ret = doubleWithRightChild(t);
}
}
t->height = std::max(height(t->left), height(t->right)) + 1;
return ret;
}
BinaryNode* rotateWithLeftChild(BinaryNode* k2)//LL调整,只需要调整双亲和左右孩子结点,平衡调整不改变前驱后继结点。
{
BinaryNode* k1 = k2->left;
k2->left = k1->right;
if (k1->right != nullptr)
{
k1->right->parent = k2;
}
k1->right = k2;
k1->parent = k2->parent;
if (k2->parent != nullptr)
{
if (k1->element < k2->parent->element)
{
k2->parent->left = k1;
}
else
{
k2->parent->right = k1;
}
}
else
{
root = k1;
}
k2->parent = k1;
k2->height = std::max(height(k2->left), height(k2->right)) + 1;
k1->height = std::max(height(k1->left), height(k1->right)) + 1;
return k1;
}
BinaryNode* rotateWithRightChild(BinaryNode* k2)//RR调整
{
BinaryNode* k1 = k2->right;
k2->right = k1->left;
if (k1->left != nullptr)
{
k1->left->parent = k2;
}
k1->left = k2;
k1->parent = k2->parent;
if (k2->parent != nullptr)
{
if (k1->element < k2->parent->element)
{
k2->parent->left = k1;
}
else
{
k2->parent->right = k1;
}
}
else
{
root = k1;
}
k2->parent = k1;
k2->height = std::max(height(k2->left), height(k2->right)) + 1;
k1->height = std::max(height(k1->left), height(k1->right)) + 1;
return k1;
}
BinaryNode* doubleWithLeftChild(BinaryNode* k3)//LR调整
{
rotateWithRightChild(k3->left);
return rotateWithLeftChild(k3);
}
BinaryNode* doubleWithRightChild(BinaryNode* k3)//RL调整
{
rotateWithLeftChild(k3->right);
return rotateWithRightChild(k3);
}
};
Set.cpp
#incude<Set.h>
int main()
{
Set<int> st1;
for (int i = 10; i < 20; i++)
{
st1.insert(i);
}
for (int i = 0; i < 5; i++)
{
st1.insert(i);
}
vector<int> vint{ 888, 7, 9, -1, -3};
st1.insert(vint.begin(), vint.end());
cout << "Dispalying data in set:\n";
cout << "------------------" << endl;
for (auto x : st1)
{
cout << x << endl;
}
cout << endl;
cout << "After deletion:\n";
cout << "------------------" << endl;
st1.remove(st1.begin(), ------st1.end());
for (auto x : st1)
{
cout << x << endl;
}
cout << "Size:" << st1.size() << endl;
}