红黑树的c++实现[原创]

代码概述

  • 本文代码经过大量测试无误(测试代码见后文),可放心食用😋。
  • 使用 valgrind 检测没有内存泄露情形。
  • 本文代码与《算法导论》给出的算法有略微差别(虽然本人也没有看过👀),原因是本文的实现思路侧重于更清晰的逻辑,有部分分支情况其实是可以合并的 。

相关概念

  • 二叉树
template <class value_type>
struct tree_node {
  value_type m_data;
  tree_node* lchild;
  tree_node* rchild;
};

二叉搜索树

  • 二叉搜索树 (binary search tree, bst) 是一颗二叉树。
  • 左子树中所有节点小于本节点,右子树中所有节点大于(或大于等于)本节点。
    • 如果要求右子树中所有节点大于本节点,则树中所有元素互不相同。
    • 如果仅要求右子树中所有节点不小于本节点,则树中允许存在相同元素。

AVL树

  • 在二叉搜索树的基础上再加以限制:任一节点的左子树的高度与右子树的高度之差不能超过1
  • 由于子树高度差的限制,AVL是一颗深度均匀的树。

红黑树

  • AVL树是深度均匀的,因此AVL树中的查询操作非常高效,但由于子树高度差不能超过1的严苛限制,对AVL树进行插入和删除的效率偏低。
  • 红黑树的限制条件弱化了子树高度差的限制,使得插入和删除效率变高,但代价是牺牲了查询操作的最高效率。因而红黑树适用于插入和删除频繁的应用场景。
  • 红黑树的性质:
    • 性质一:红黑树的根节点是黑色的。
    • 性质二:红黑树中不存在相邻的红色节点。
    • 性质三:红黑树中所有从根节点到叶子节点的路径中包含的黑色节点的数目是一致的。
    • 性质四:红黑树中的叶子节点定义为nil,并且定义 nil 节点为黑色。即如果一个节点的左/右孩子指针为 nullptr,我们不再说它的左/右孩子为空,而是说它的左/右孩子为 nil 节点。
  • 红黑树的重要推论:如果一个节点是红黑树中的节点,那么以该节点为根节点的子树一定也满足性质二,三,四。
  • 红黑树性质分析:
    • 性质四可有可无,但性质四的存在有助于简化插入操作和删除操作的分类讨论。即不再需要考虑某一个非 nil 节点的左孩子节点/右孩子节点是否为空的问题,当左孩子节点/右孩子节点为空时,我们认为这个左孩子/右孩子节点是 nil 节点,则左孩子/右孩子为空时,我们认为这个左孩子/右孩子是黑色的(性质四规定 nil 节点是黑色的)。
    • 性质二和性质三共同作用限制了左右树高的差距,极端情况下一边子树的最长路径全部是黑色节点,另一边子树的最长路径是红色节点和黑色节点交替分布,此时较长一边子树是较矮一边子树高度的二倍。
    • 性质一是因为如果一颗红黑树一开始只有一个根节点 (非 nil 节点),假设我们一开始把这个根节点设置为红色,接下来如果我们插入一个新的节点,那么这个新的节点必然会成为根节点的左孩子/右孩子,此时这个新节点必须被赋为黑色(否则违反性质二),而新节点为黑色后又违反了性质三,所以我们不得不改变根节点的颜色为黑色,此时还需要再把新节点调整为红色,此时才满足了性质二和性质三,既然如此,我们可以一开始就把根节点限制为黑色,免得后面修修改改。

代码环境要求

  • 安装有支持 c++11 的编译器即可

结点定义

  • 在二叉树节点的基础上,添加了父亲节点指针和节点颜色两个字段。添加父亲节点指针是因为在插入和删除操作的调整中经常需要访问一个结点的父亲节点,所以我们直接使用一个字段保存节点的父亲节点。
  • 结点的定义借鉴了stl,我们在之前定义的红黑树的基础之上再添加一个头节点,头节点的存在只是为了方便对真正的红黑树进行管理,因而不需要存在数据域,因而我们定义了节点基类和节点类两个类型,头节点只需要是 rbnode_base 类型即可。我们让头节点的左孩子指向红黑树最左边的节点,让头节点的右孩子指向红黑树最右边的节点,让头节点的父亲节点指向红黑树的根节点,让根节点的父亲节点指向头节点。这样我们就可以很方便地通过头节点对红黑树进行访问。
  • 如果红黑树是空的,我们将头节点的父亲节点设置为 nil,将头节点的左孩子设置为头节点自己,右孩子也设置为头节点自己。
  • 在后文中,我们记头节点的父亲节点为 root(),头节点的左孩子为 leftmost(),右孩子为 rightmost()。显然 leftmost() 指向的就是红黑树中值最小的节点,rightmost() 指向的就是红黑树中值最大的节点。
  • 关于头节点:不仅在红黑树中,在链表的实现中也可以添加一个头节点以方便管理。
  • 下面给出了一颗添加了头节点的红黑树:
    在这里插入图片描述
// 定义颜色
enum class rbcolor { red = false, blk = true };

// 定义节点基类
struct rbnode_base {
  rbcolor m_color;
  rbnode_base* m_parent;
  rbnode_base* m_lchild;
  rbnode_base* m_rchild;
};

// 定义节点
template <typename T>
struct rbnode : public rbnode_base {
  T m_data;
};

插入节点

  • insert(const T& value)
    • 首先找到插入位置所在的节点 (采用二分法搜索即可)。记该节点为 pos,记待插入的值为 value,然后创建一个值为 value 的新节点,如果 value < pos->m_data,将新节点作为 pos 的左孩子节点链接到树中,否则将新节点作为 pos 的右孩子节点链接到树中。
    • 实现:
      bool insert(const T& value) {
      	node* pos = root();	// pos 初始化为根节点
      	node* pos_parent = head(); // 记录 pos 的父结点
      	while (pos != nullptr) {
      		pos_parent = pos;
      		if (value == pos->data())
      			return false;
      		else if (value < pos->data())
      			pos = pos->lchild();
      		else
      			pos = pos->rchild();
      	}
      	/**
      	 * 退出循环后 pos 为 nullptr (即 nil 节点)。
      	 * 我们根据 value 新建一个节点,然后用这个新节点代替 pos,也就是说
      	 * 如果 pos 是 pos_parent 的左孩子,那么我们将新节点链接为 pos_parent
      	 * 的左孩子,否则链接为 pos_parent 的右孩子。
      	 * pos_parent 的值一定不会和 value 相等,否则会在循环中直接返回 false.
      	 */
      	__insert(pos_parent, value);
      	++m_node_count;
      	return true;
      }
      
  • __insert(node* pos_parent, const T& value)
    • 首先考虑特殊情况:如果 pos_parent 是头节点,我们直接创建新节点 new_node,然后更新 leftmost() = rightmost() = root() = new_nodenew_node 颜色设置为黑色(性质1),然后插入操作就完成了,不过注意还需要设置 new_node 的左右孩子和父节点,然后直接返回。
    • 排除了特殊情况后,我们知道 insert 函数保证了调用 __insert 函数时, pos_parent 的值一定不会和 value 相等,那么我们分为两类进行讨论:
      • 如果 value < pos_parent->data():将 new_node 链接为 pos_parent 的左孩子,这里 pos_parent 有可能是 leftmost(),这样的话我们把 leftmost() 也顺带更新一下,注意 value < pos_parent ,所以 new_node 不可能成为新的 rightmost()
      • 如果 value > pos_parent->data():将 new_node 链接为 pos_parent 的右孩子,同样需要注意 rightmost() 是否需要更新。
    • 链接新节点的时候还需要注意设置新节点的父亲节点,后文同理。
    • 完成插入后,我们还有新节点的颜色未设置,不论如何设置新节点的颜色,新节点的插入都有可能打破红黑树的几条性质,所以我们把新节点的颜色交给 __preFixInsert(new_node) 函数设置。
    • 实现:
      void __insert(node* pos_parent, const T& value) {
      	node* new_node = create_node(value);
      	if (pos_parent == head()) {
      		leftmost() = rightmost() = root() = new_node;
      		new_node->setBlk();
      		new_node->parent() = pos_parent;
      		new_node->lchild() = nullptr;
      		new_node->rchild() = nullptr;
      		return;
      	}
      	if (value < pos_parent->data()) {
      		pos_parent->lchild() = new_node;
      		if (pos_parent == leftmost())
      			leftmost() = new_node;
      	} else {
      		pos_parent->rchild() = new_node;
      		if (pos_parent == rightmost())
      			rightmost() = new_node;
      	}
      	new_node->parent() = pos_parent;
      	new_node->lchild() = nullptr;
      	new_node->rchild() = nullptr;
      	__preFixInsert(new_node);
      }
      
  • __preFixInsert(node* new_node)
    • 对于新插入的节点 new_node,我们将其设置为红色。这里为什么不设置为黑色呢?如果我们把 new_node 设置为黑色,new_node 所在的分支的黑高将会比其他分支多一,那么我们必须进行修复,但是如果我们把 new_node 设置为红色,是有可能不用进行额外的修复的,所以就设置为红色。
    • __insert 函数保证了调用 __preFixInsert 函数时 new_node 一定不会是根节点。
    • new_node 设置为红色后:
      • 如果 new_node 的父亲节点是黑色的,那么插入操作并没有破坏红黑树的任何性质,我们直接返回。
      • 如果 new_node 的父亲节点是红色的,那么现在我们违反了性质二,需要进行修复,修复交给 __fixInsert(new_node) 函数处理。
    • 实现:
      void __preFixInsert(node* new_node) {
      	new_node->setRed();
      	if (new_node->parent()->isRed())
      		__fixInsert(new_node);
      }
      
  • 左旋和右旋操作
    • 我们以 ".." 代表黑高为 N 的子树,以 "..." 代表黑高为 N + 1 的子树 (N = 0, 1, 2, ...)。这里因为 "..." 子树可能为 nil".." 子树可能连 nil 都不是 (参考 N = 0 的情况),所以我们不再显示地标注 "..""..." 子树的父节点指向线,最下面的 NN + 1N + 2 等代表对应分支的黑高。
    • 对图中 pos 节点右旋 (pos 也可以是黑色的,左旋和右旋操作本身并不关心节点颜色):
      在这里插入图片描述
    • 对图中 pos 节点左旋(pos 也可以是黑色的,左旋和右旋操作本身并不关心节点颜色):
      在这里插入图片描述
    • 左旋和右旋后有可能需要更新 root(),这种可能的情形我们在左旋和右旋函数中处理,另外,左旋和右旋操作后 leftmost()rightmost() 一定不需要更新 (因为树中存储的节点的集合没有变化,leftmost()rightmost() 本质其实就是指向树中的最小值和最大值的节点,集合没有变化,那显然最小值和最大值所在的节点也不会变)。
  • __fixInsert(node* new_node)
    • new_node 的父节点为 parent,爷爷节点为 gparent,叔叔节点为 uncle
    • 现有条件:
      1. new_node 一定不是根节点 (由 __insert 函数保证)。
      2. new_node 一定是红色的 (由 __preFixInsert 函数保证)。
      3. parent 一定是红色的 (由 __preFixInsert 函数保证)。
      4. parent 一定不是根节点 (由性质一 + 现有条件3 保证)。
      5. gparent 一定是黑色的 (由性质二 + 现有条件3 保证)。
      6. new_nodeparentgparent 都不会是头节点 (易知)。
    • 算法递归时,我们同样保证上述六个条件成立。
    • 分类讨论:(这里假定 parentgparent 的左孩子,如果 parentgparent 的右孩子,所有操作都是下面讨论的镜像操作,如何进行镜像后文会给出。)
      • case 1:uncle 是黑色的
        • case 1.1:new_nodeparent 的左孩子
          在这里插入图片描述
        • case 1.2:new_nodeparent 的右孩子
          在这里插入图片描述
      • case 2:uncle 是红色的。case 2 中不再需要区分 new_nodeparent 的左孩子还是右孩子,这里我们绘图以 new_nodeparent 的左孩子为例:
        • case 2.1:gparent 已经是根节点了
          在这里插入图片描述
        • case 2.2:gparent 不是根节点
          这里仅给出 gparent 的父亲节点是红色时的修复示意图。如果gparent 的父亲节点是黑色的,修复操作是一样的,但是不用继续向上递归了,因为修复已经完成了。
          在这里插入图片描述
    • 如果 parentgparent 的右孩子,所有操作都是 parentgparent 的左孩子的操作的镜像操作,在代码实现中只需要复制 parentgparent 的左孩子时的代码,然后将 ->lchild() 全部修改为 ->rchild(),将 ->rchild() 全部修改为 ->lchild(),将所有左旋的地方改为右旋,将所有右旋的地方改为左旋就好了。
    • 实现:
      void __fixInsert(node* new_node) {
      	node* parent = new_node->parent();
      	node* gparent = new_node->gparent();
      	node* uncle = new_node->uncle();
      	while (true) {
      		if (parent == gparent->lchild()) {
      			if (uncle == nullptr || uncle->isBlk()) {	// case 1
      				if (new_node == parent->rchild()) { // case 1.2
      					__leftRotate(parent);
      					new_node = parent;
      					// 更新各个亲戚,注意到爷爷和叔叔都没变不用更新,只需要更新父亲。
      					parent = new_node->parent();
      					// 现在转换为了 case 1.1。
      				}
      				// case 1.1
      				parent->setBlk();
      				gparent->setRed();
      				__rightRotate(gparent);
      				return; // 修复结束,函数返回。
      			} else {	// case 2
      				// case 2.1 和 case 2.2 的公共操作
      				parent->setBlk();
      				uncle->setBlk();
      										
      				if (gparent == root()) {	// case 2.1
      					return;	// 修复结束,函数返回。
      				} else {	// case 2.2
      					gparent->setRed();
      					if (gparent->parent()->isBlk())
      						return;	// 修复结束,函数返回。
      					new_node = gparent;
      					// 更新各个亲戚,父亲,爷爷,叔叔都变了都要更新。
      					parent = new_node->parent();
      					gparent = new_node->gparent();
      					uncle = new_node->uncle();
      					
      					continue;	// 继续向上修复,进入下一轮 while 循环
      				}
      			}
      		} else { // 上面过程的镜像操作
      			if (uncle == nullptr || uncle->isBlk()) {
      				if (new_node == parent->lchild()) {
      					__rightRotate(parent);
      					new_node = parent;
      					parent = new_node->parent();
      				}
      				parent->setBlk();
      				gparent->setRed();
      				__leftRotate(gparent);
      				return;
      			} else {
      				parent->setBlk();
      				uncle->setBlk();
      				if (gparent == root()) {
      					return;
      				} else {
      					gparent->setRed();
      					if (gparent->parent()->isBlk())
      						return;
      					new_node = gparent;
      					parent = new_node->parent();
      					gparent = new_node->gparent();
      					uncle = new_node->uncle();
      					continue;
      				}
      			}
      		}
      	}
      	return; // never reach.
      }
      

删除节点

  • erase(const T& value)
    • 根据想要删除的值找到这个值所在的节点,如果这个值不在当前树中,直接返回 false,否则进入下一步删除这个值所在的节点。
    • 实现:
      bool erase(const T& value) {
      	node* pos = root();	// pos 初始化为根节点
      	while (pos != nullptr) {
      		if (value == pos->data()) {
      			__erase(pos);
      			--m_node_count;
      			return true;
      		} 
      		if (value < pos->data())
      			pos = pos->lchild();
      		else
      			pos = pos->rchild();
      	}
      	// 退出循环还是没找到,返回 false
      	return false;
      }
      
  • __erase(node* pos)
    • 找到值为 value 的节点 pos 后,我们要把 pos 节点从树中删除。我们先直接删除,如果破坏了某些性质,再进行修复。此时的直接删除,和二叉排序树的删除是一模一样的,但是由于我们的实现添加了一个头节点,需要注意 root()leftmost()rightmost() 是否需要更新。
    • case 1:如果 pos 的左右孩子都是 nil,也就是说 pos 是普通二叉树意义下的叶子节点,那么删除操作是非常简单的。
    • case 2:如果 pos 的左孩子是 nil,右孩子不是 nil,这种情况也很简单,让 pos 的右孩子顶替 pos 的位置,然后删除 pos 就可以了。
    • case 3:如果 pos 的右孩子是 nil,左孩子不是 nil,这种情况是 case 2 的镜像,操作类似。
    • case 4:如果 pos 的左右孩子都不是 nil,我们先找到 pos 中序遍历的下一个节点,记这个节点为 successor,然后将 pos 的值修改为 successor 的值,此时我们就把删除 pos 等效地变换为了删除 successor,而 successor 的左孩子一定是 nil (否则真正的 successor 应该是在现在这个 successor 的左子树里面),因此删除 successor 一定满足 case 1case2。我们按照 case 1case 2 的方法删除 successor 即可。
    • 推论:对于 case 2case 3 我们可以推论得到 pos 一定是黑色的,pos 的非 nil 的那个孩子一定是红色的,而且这个孩子的左右孩子一定是 nil (根据红黑树的性质二和性质三很容易推理得到)。
    • 现在我们已经完成直接删除步骤了,注意我们最后真正删除的节点一定只可能是 case 1case 2case3,我们接下来分析这三种情形下什么时候需要进行修复。
      • 对于 case 1,如果 pos 是红色节点,删除后不破坏任何性质,无需修复,直接返回;如果 pos 是黑色节点,删除 pos 后会破坏性质三,而且恰好是使得这个分支的黑高少了一(因为我们只删除了一个节点),此时我们需要调用 __preFixErase 进行修复。
      • 对于 case 2case 3,根据前面的推论,pos 一定是黑色的,删除 pos 后会破坏性质三,而且恰好是使得这个分支的黑高少了一(因为我们只删除了一个节点),此时我们需要调用 __preFixErase 进行修复。
    • 实现:
      void __erase(node* pos) {
      	// pos 一定不会是 nullptr,由调用者 erase(const T& value) 保证。
      	node* successor; // pos 删除后的接盘者
      	if (pos->lchild() == nullptr && pos->rchild() == nullptr) {	// case 1
      		successor = nullptr;
      		if (pos == root()) {
      			root() = nullptr;
      			leftmost() = head();
      			rightmost() = head();
      		} else {
      			if (pos == pos->parent()->lchild())
      				pos->parent()->lchild() = successor;
      			else
      				pos->parent()->rchild() = successor;
      			
      			if (pos == leftmost())
      				leftmost() = pos->parent();
      			if (pos == rightmost())
      				rightmost() = pos->parent();
      			
      			node* suc_parent = pos->parent(); // successor 的父亲
      			if (pos->isBlk()) // 不需要判断 pos 可能为 nullptr 的情况,调用者 erase 函数保证了 pos 不会是 nullptr
      				__preFixErase(successor, suc_parent);	// 此时调用 __preFixErase,successor 一定不会是根节点
      		}
      	} else if (pos->lchild() != nullptr && pos->rchild() != nullptr) {	// case 4
      		// __get_successor()函数找到 pos 中序遍历的下一个节点
      		successor = __get_successor(pos);
      		pos->data() = std::move(successor->data());
      		// 现在我们将 __erase(pos) 等价地转换为了 __erase(successor).
      		__erase(successor); // 递归调用,根据上文分析此处递归只会递归一次。
      		return;
      	} else {	// case 2 和 case 3
      		if (pos->lchild() == nullptr)
      			successor = pos->rchild();
      		else
      			successor = pos->lchild();
      		
      		// 注意我们前面的推论在这里有用:根据推论,
      		// successor 一定不是 nil,successor 一定是红色的,
      		// successor 的左右孩子一定是 nil。
      		// pos 一定是黑色的。
      
      		if (pos == root()) {
      			root() = successor;
      			leftmost() = successor;
      			rightmost() = successor;
      			successor->setBlk();
      		} else {
      			if (pos == pos->parent()->lchild())
      				pos->parent()->lchild() = successor;
      			else
      				pos->parent()->rchild() = successor;
      			successor->parent() = pos->parent(); // 不用担心successor可能是 nullptr,不理解的话参考上面的推论
      			if (pos == leftmost())
      				leftmost() = successor;
      			if (pos == rightmost())
      				rightmost() = successor;
      			
      			node* suc_parent = successor->parent();
      			// 此时调用 __preFixErase,successor 一定不会是根节点
      			__preFixErase(successor, suc_parent);
      		}
      	}
      	destroy_node(pos);
      }
      
  • __preFixErase(node* successor, node* suc_parent)
    • 现有条件:successor 一定不是根节点(由调用者 __erase 函数保证)。
    • 此时,如果 successor 是红色的,我们直接将 successor 染黑,性质三就满足了,此时修复完成,直接返回。而如果 successor 是黑色的,我们则需要继续调用 __fixErase 函数进行修复。
    • 实现:
      void __preFixErase(node* successor, node* suc_parent) {
      	if (successor != nullptr && successor->isRed()) {
      		successor->setBlk();
      		return;
      	}
      	__fixErase(successor, suc_parent);
      }
      
  • __fixErase(node* successor, node* suc_parent)
    • 这里为什么还要传参 suc_parent ?,你可能会觉得 successor 的父亲节点可以直接通过 successor->parent() 得到,但注意 __erase 函数并不保证调用 __fixErasesuccessor 不是 nullptr,所以我们不得不将 successor 的父亲节点通过参数传递进来。
    • 我们用 parentbrother 代表 successor 的父亲节点和兄弟节点。
    • 现有条件:
      1. successor 一定是黑色节点 (由 __preFixErase 函数保证)。
      2. successor 一定不是根节点 (由 __erase 函数保证)。
      3. successor 所在分支黑高恰好比其他分支少 1 (由 __erase 函数保证)。
      4. brother 一定不是 nil (由现有条件三保证)。
      5. successorparentbrother 都不会是头节点 (易知)。
      6. 除了性质三,当前树满足其他三条性质。
    • 算法递归时,我们同样保证上述六个条件成立。
    • 后文图中蓝色的结点表示这个结点既可以是红色的,也可以是黑色的。
    • 分类讨论 (仅讨论 successorparent 的左孩子的情况,successorparent 的右孩子时是镜像操作):
      • case 1:brother 是红色的。注意此时 parent 一定是黑色 (由性质二)。
        在这里插入图片描述
      • case 2:brother 是黑色的
        • case 2.1:brother 的左右孩子都是黑色的
          在这里插入图片描述
        • case 2.2:brother 的左孩子是红色的,右孩子是黑色的
          在这里插入图片描述
        • case 2.3:brother 的左孩子是黑色的,右孩子是红色的
          在这里插入图片描述
        • case 2.4:brother 的左右孩子都是红色的
          在这里插入图片描述
    • successorparent 的右孩子的操作完全是 successorparent 的左孩子的操作的镜像操作,不再赘述。
    • 实现:
    void __fixErase(node* successor, node* suc_parent) {
    	node* parent = suc_parent;
    	node* brother;
    	while (true) {
    		if (successor == parent->lchild()) {
    			brother = parent->rchild();
    			if (brother->isRed()) {	// case 1 ⇒ 转换为 case 2
    				parent->setRed();
    				brother->setBlk();
    				__leftRotate(parent);
    				// 更新 brother,
    				// 注意,不可以直接使用:brother = successor->brother(),
    				// 因为 successor 有可能为 nullptr,
    				// 也就是说 successor 指向的节点可能是 nil。 
    				brother = parent->rchild();
    			}
    			// case 2
    			if ((brother->lchild() == nullptr || brother->lchild()->isBlk()) 
    			&& (brother->rchild() == nullptr || brother->rchild()->isBlk())) {
    				// case 2.1
    				brother->setRed();
    				if (parent->isRed()) {
    					parent->setBlk();
    					return;
    				}
    				if (parent == root())
    					return;
    				successor = parent;
    				// 更新各个亲戚,这里可以直接使用 successor-> 操作,
    				// 因为这里 successor 一定不会是 nullptr。
    				parent = successor->parent();
    				brother = successor->brother();
    				continue;
    			} else {
    				if (brother->rchild() == nullptr || brother->rchild()->isBlk()) {
    					// case 2.2 ⇒ 转换为 case 2.3 or case 2.4
    					brother->setRed();
    					brother->lchild()->setBlk(); // 此时 brother->lchild() 一定不会是 nullptr
    					__rightRotate(brother);
    					// 刷新 brother,同上因为 successor 可能为 nullptr,
    					// 不能使用 brother = successor->brother() 进行更新。
    					brother = parent->rchild();					}
    				// case 2.3 和 case 2.4
    				brother->color() = parent->color();
    				parent->setBlk();
    				brother->rchild()->setBlk(); // 此时 brother->rchild() 一定不会是 nullptr
    				__leftRotate(parent);
    				return;
    			}
    		} else {	// mirror operation
    			brother = parent->lchild();
    			if (brother->isRed()) {	// case 1 ⇒ 转换为 case 2
    				parent->setRed();
    				brother->setBlk();
    				__rightRotate(parent);
    				brother = parent->lchild();
    			}
    			// case 2
    			if ((brother->rchild() == nullptr || brother->rchild()->isBlk()) 
    			&& (brother->lchild() == nullptr || brother->lchild()->isBlk())) {
    				// case 2.1
    				brother->setRed();
    				if (parent->isRed()) {
    					parent->setBlk();
    					return;
    				}
    				if (parent == root())
    					return;
    				successor = parent;
    				// 更新各个亲戚
    				parent = successor->parent();
    				brother = successor->brother();
    				continue;
    			} else {
    				if (brother->lchild() == nullptr || brother->lchild()->isBlk()) {
    					// case 2.2 ⇒ 转换为 case 2.3 or case 2.4
    					brother->setRed();
    					brother->rchild()->setBlk(); // 此时 brother->lchild() 一定不会是 nullptr
    					__leftRotate(brother);
    					brother = parent->lchild(); // 刷新 brother
    				}
    				// case 2.3 和 case 2.4
    				brother->color() = parent->color();
    				parent->setBlk();
    				brother->lchild()->setBlk(); // 此时 brother->rchild() 一定不会是 nullptr
    				__rightRotate(parent);
    				return;
    			}
    		}
    	}
    	return; // never reach
    }
    

c++代码实现

  • 按照前文所述思路的实现
#pragma once

#include <sstream>
#include <iostream>

enum class rbcolor { red = false, blk = true };

struct rbnode_base {
  rbcolor m_color;
  rbnode_base* m_parent;
  rbnode_base* m_lchild;
  rbnode_base* m_rchild;
};

template <typename T>
struct rbnode : public rbnode_base {
  T m_data;

  T& data() const { return (T&)m_data; }

  rbcolor& color() const { return (rbcolor&)m_color; }

  rbnode*& parent() const { return (rbnode*&)m_parent; }

  rbnode*& lchild() const { return (rbnode*&)m_lchild; }

  rbnode*& rchild() const { return (rbnode*&)m_rchild; }

  rbnode* brother() const {
    if (this == parent()->lchild())
      return parent()->rchild();
    else
      return parent()->lchild();
  }

  rbnode* gparent() const { return parent()->parent(); }

  rbnode* uncle() const {
    rbnode* father = parent();
    rbnode* grandfather = gparent();
    if (father == grandfather->lchild())
      return grandfather->rchild();
    else
      return grandfather->lchild();
  }

  bool isRed() const { return m_color == rbcolor::red; }

  bool isBlk() const { return m_color == rbcolor::blk; }

  void setRed() { m_color = rbcolor::red; }

  void setBlk() { m_color = rbcolor::blk; }

};

template <typename T>
class rbtree {
 protected:
  using node = rbnode<T>;

  rbnode_base m_head;
  size_t m_node_count;

 public:
  rbtree() { reset(); }

  ~rbtree() { clear(); }

 public:
  bool insert(const T& value) {
    node* pos = root();
    node* pos_parent = head();
    while (pos != nullptr) {
      pos_parent = pos;
      if (value == pos->data())
        return false;
      else if (value < pos->data())
        pos = pos->lchild();
      else
        pos = pos->rchild();
    }
    __insert(pos_parent, value);
    ++m_node_count;
    return true;
  }

  bool erase(const T& value) {
    node* pos = root();
    while (pos != nullptr) {
      if (value == pos->data()) {
        __erase(pos);
        --m_node_count;
        return true;
      } 
      if (value < pos->data())
        pos = pos->lchild();
      else
        pos = pos->rchild();
    }
    return false;
  }

  void clear() {
    if (!empty()) {
      __destroy(root());
      reset();
    }
  }

  size_t size() const
  { return m_node_count; }

  bool empty() const
  { return size() == 0; }

  void disp() const {
    if (!empty())
      __disp(root());
    std::cout << std::endl;
  }

 protected:
  node* head() const
  { return (node*)(&m_head); }

  node*& root() const
  { return (node*&)m_head.m_parent; }

  node*& leftmost() const
  { return (node*&)m_head.m_lchild; }

  node*& rightmost() const
  { return (node*&)m_head.m_rchild; }

  node* create_node(const T& x) { 
    node* tmp = (node*)malloc(sizeof(node));
    try {
       new (&tmp->m_data) T(x);
    } catch (...) {
      free((void*)tmp);
    }
    return tmp;
  }

  void destroy_node(node* p) {
    (&p->m_data)->~T();
    free(p);
  }
  
  void reset() {
    m_head.m_color = rbcolor::red;
    root() = nullptr;
    leftmost() = head();
    rightmost() = head();
    m_node_count = 0;
  }

 protected:
  void __leftRotate(node* x) {
    node* rchild = x->rchild();
    x->rchild() = rchild->lchild();
    if (rchild->lchild() != nullptr)
      rchild->lchild()->parent() = x;
    if (x == root())
      root() = rchild;
    else if (x == x->parent()->lchild())
      x->parent()->lchild() = rchild;
    else 
      x->parent()->rchild() = rchild;
    rchild->parent() = x->parent();
    rchild->lchild() = x;
    x->parent() = rchild;
  }

  void __rightRotate(node* x) {
    node* lchild = x->lchild();
    x->lchild() = lchild->rchild();
    if (lchild->rchild() != nullptr)
      lchild->rchild()->parent() = x;
    if (x == root())
      root() = lchild;
    else if (x == x->parent()->lchild())
      x->parent()->lchild() = lchild;
    else 
      x->parent()->rchild() = lchild;
    lchild->parent() = x->parent();
    lchild->rchild() = x;
    x->parent() = lchild;
  }

  node* __get_successor(node* x) {
    if (x == head()) {
      x = x->lchild();
    } else if (x->rchild() != nullptr) {
      node* y = x->rchild();
      while (y->lchild() != nullptr)
        y = y->lchild();
      x = y;
    } else {
      node* y = x->parent();
      while (x == y->rchild()) {
        x = y;
        y = y->parent();
      }
      /**
       * 假如最开始传入的参数 x 是根节点,并且 x 的右孩子为空,
       * 我们期望 x 的 successor 应该是头节点 (类似于循环链表)。
       * 下面这个 if 语句就是为了保证这种期望的。
       */
      if (x->rchild() != y)
        x = y;
    }
    return x;
  }

  void __disp(node* rt) const {
    if (rt->isRed())
      std::cout << "\033[1;31m" << rt->m_data;
    else 
      std::cout << "\033[1;34m" << rt->m_data;
    std::cout << "\033[0m";
    if (rt->lchild() || rt->rchild()) {
      std::cout << '(';
      if (rt->lchild())
        __disp(rt->lchild());
      if (rt->rchild()) {
        std::cout << ", ";
        __disp(rt->rchild());
      }
      std::cout << ')';
    }
  }

  void __destroy(node* pos) {
    while (pos != nullptr) {
      __destroy(pos->rchild());
      node* tmp = pos->lchild();
      destroy_node(pos);
      pos = tmp;
    }
  }

  void __insert(node* pos_parent, const T& value) {
    node* new_node = create_node(value);
    if (pos_parent == head()) {
      leftmost() = rightmost() = root() = new_node;
      new_node->setBlk();
      new_node->parent() = pos_parent;
      new_node->lchild() = nullptr;
      new_node->rchild() = nullptr;
      return;
    } else if (value < pos_parent->data()) {
      pos_parent->lchild() = new_node;
      if (pos_parent == leftmost())
        leftmost() = new_node;
    } else {
      pos_parent->rchild() = new_node;
      if (pos_parent == rightmost())
        rightmost() = new_node;
    }
    new_node->parent() = pos_parent;
    new_node->lchild() = nullptr;
    new_node->rchild() = nullptr;
    __preFixInsert(new_node);
  }

  void __erase(node* pos) {
    node* successor;
    if (pos->lchild() == nullptr && pos->rchild() == nullptr) {
      successor = nullptr;
      if (pos == root()) {
        root() = nullptr;
        leftmost() = head();
        rightmost() = head();
      } else {
        if (pos == pos->parent()->lchild())
          pos->parent()->lchild() = successor;
        else
          pos->parent()->rchild() = successor;
        
        if (pos == leftmost())
          leftmost() = pos->parent();
        if (pos == rightmost())
          rightmost() = pos->parent();
        
        if (pos->isBlk()) {
          node* suc_parent = pos->parent();
          __preFixErase(successor, suc_parent);
        }
      }
    } else if (pos->lchild() != nullptr && pos->rchild() != nullptr) {
      successor = __get_successor(pos);
      pos->data() = std::move(successor->data());
      __erase(successor);
      return;
    } else {
      if (pos->lchild() == nullptr)
        successor = pos->rchild();
      else
        successor = pos->lchild();

      if (pos == root()) {
        root() = successor;
        leftmost() = successor;
        rightmost() = successor;
        successor->setBlk();
      } else {
        if (pos == pos->parent()->lchild())
          pos->parent()->lchild() = successor;
        else
          pos->parent()->rchild() = successor;
        successor->parent() = pos->parent();
        if (pos == leftmost())
          leftmost() = successor;
        if (pos == rightmost())
          rightmost() = successor;
        
        node* suc_parent = successor->parent();
        __preFixErase(successor, suc_parent);
      }
    }
    destroy_node(pos);
  }

  void __preFixInsert(node* new_node) {
    new_node->setRed();
    if (new_node->parent()->isRed())
      __fixInsert(new_node);
  }

  void __preFixErase(node* successor, node* suc_parent) {
    if (successor != nullptr && successor->isRed()) {
      successor->setBlk();
      return;
    }
    __fixErase(successor, suc_parent);
  }

  void __fixInsert(node* new_node) {
    node* parent = new_node->parent();
    node* gparent = new_node->gparent();
    node* uncle = new_node->uncle();
    while (true) {
      if (parent == gparent->lchild()) {
        if (uncle == nullptr || uncle->isBlk()) {
          if (new_node == parent->rchild()) {
            __leftRotate(parent);
            new_node = parent;
            parent = new_node->parent();
          }
          parent->setBlk();
          gparent->setRed();
          __rightRotate(gparent);
          return;
        } else {
          parent->setBlk();
          uncle->setBlk();
                      
          if (gparent == root())
            return;
          gparent->setRed();
          if (gparent->parent()->isBlk())
            return;
          new_node = gparent;
          parent = new_node->parent();
          gparent = new_node->gparent();
          uncle = new_node->uncle();
          continue;
        }
      } else {
        if (uncle == nullptr || uncle->isBlk()) {
          if (new_node == parent->lchild()) {
            __rightRotate(parent);
            new_node = parent;
            parent = new_node->parent();
          }
          parent->setBlk();
          gparent->setRed();
          __leftRotate(gparent);
          return;
        } else {
          parent->setBlk();
          uncle->setBlk();
          if (gparent == root())
            return;
          gparent->setRed();
          if (gparent->parent()->isBlk())
            return;
          new_node = gparent;
          parent = new_node->parent();
          gparent = new_node->gparent();
          uncle = new_node->uncle();
          continue;
        }
      }
    }
    return; // never reach.
  }

  void __fixErase(node* successor, node* suc_parent) {
    node* parent = suc_parent;
    node* brother;
    while (true) {
      if (successor == parent->lchild()) {
        brother = parent->rchild();
        if (brother->isRed()) {
          parent->setRed();
          brother->setBlk();
          __leftRotate(parent);
          brother = parent->rchild();
        }
        if ((brother->lchild() == nullptr || brother->lchild()->isBlk()) 
         && (brother->rchild() == nullptr || brother->rchild()->isBlk())) {
          brother->setRed();
          if (parent->isRed()) {
            parent->setBlk();
            return;
          }
          if (parent == root())
            return;
          successor = parent;
          parent = successor->parent();
          brother = successor->brother();
          continue;
        } else {
          if (brother->rchild() == nullptr || brother->rchild()->isBlk()) {
            brother->setRed();
            brother->lchild()->setBlk();
            __rightRotate(brother);
            brother = parent->rchild();
          }
          brother->color() = parent->color();
          parent->setBlk();
          brother->rchild()->setBlk();
          __leftRotate(parent);
          return;
        }
      } else {
        brother = parent->lchild();
        if (brother->isRed()) {
          parent->setRed();
          brother->setBlk();
          __rightRotate(parent);
          brother = parent->lchild();
        }
        if ((brother->rchild() == nullptr || brother->rchild()->isBlk()) 
         && (brother->lchild() == nullptr || brother->lchild()->isBlk())) {
          brother->setRed();
          if (parent->isRed()) {
            parent->setBlk();
            return;
          }
          if (parent == root())
            return;
          successor = parent;
          parent = successor->parent();
          brother = successor->brother();
          continue;
        } else {
          if (brother->lchild() == nullptr || brother->lchild()->isBlk()) {
            brother->setRed();
            brother->rchild()->setBlk();
            __leftRotate(brother);
            brother = parent->lchild();
          }
          brother->color() = parent->color();
          parent->setBlk();
          brother->lchild()->setBlk();
          __rightRotate(parent);
          return;
        }
      }
    }
    return; // never reach.
  }
  
};

更完整的c++代码实现

  • 相较于之前的实现,增加了内存池,允许插入值相同的结点等细节,其余地方代码也有略微调整。
#pragma once

#include <atomic>
#include <memory>
#include <sstream>
#include <iostream>

#define __TINY_MEMPOOL__

class tiny_mempool {
 protected:
 struct memNode { memNode *nextnode = nullptr; };

 protected:
  std::atomic<memNode*> m_free_head[16];

 private:
  tiny_mempool() {}

  ~tiny_mempool()
  { for (int i = 0; i < 16; i++)
    { if (m_free_head[i] != nullptr)
      { memNode *ptr = m_free_head[i];
        while (ptr != nullptr)
        { auto nptr = ptr->nextnode;
          free(ptr);
          ptr = nptr;
        }
      }
      m_free_head[i] = nullptr;
    }
  }

  int getindex(int size)
  { static const unsigned int sizetable[16]
    = { 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128 };
    int __distance = 0;
    for (; __distance < 16; __distance++)
      if (sizetable[__distance] >= size)
        break;
    return __distance;
  }

 public:
  static tiny_mempool &instance()
  { static tiny_mempool pool;
    return pool;
  }

 public:
  void *alloc(int size)
  { if (size > 128) return malloc(size);
    int index = getindex(size);
    int realsize = (index + 1) << 3;
    memNode *p = m_free_head[index];
    if (p == nullptr)
      return malloc(realsize);
    else
    { while (!m_free_head[index].compare_exchange_weak(p, p->nextnode))
        if (p == nullptr) return malloc(realsize);
      return p;
    }
    return nullptr;
  }

  void delloc(void *ptr, int size)
  { if (ptr == nullptr) return;
    if (size > 128) return free(ptr);
    int index = getindex(size);
    memNode *pNew = (memNode *)ptr;
    pNew->nextnode = m_free_head[index];
    while (!(m_free_head[index].compare_exchange_weak(pNew->nextnode, pNew)))
    ;
  }
  
  /**
   * @brief report memory distribute in the pool.
   * @attention May cause undefined result if 
   * allocate memory use current pool before this
   * function return.
   */
  void report()
  { printf("\033[32m\033[1mtiny_mempool report\033[0m\n");
    printf("\033[34mindex\tnode size   node count\033[0m\n");
    for (int i = 0; i < 16; ++i)
    { int n = 0;
      memNode *p = m_free_head[i];
      while (p)
      { n++;
        p = p->nextnode;
      }
      printf("\033[31m%5d\t %3d \033[35mbyte\033[31m   %10d"
             "\033[0m\n", i, (i + 1) << 3, n);
    }
  }

};

template<class T>
class tiny_allocator { 
 public:
  using value_type = T;
  using pointer = T*;
  using const_pointer = const T*;
  using reference = T&;
  using const_reference = const T&;
  using size_type = size_t;

  tiny_allocator() {}

  tiny_allocator(tiny_allocator const &) {}

  tiny_allocator &operator=(tiny_allocator const &)
  { return *this; }

  template<class Other>
  tiny_allocator(tiny_allocator<Other> const &) {}

  template<class Other>
  tiny_allocator &operator=(tiny_allocator<Other> const &)
  { return *this; }

  pointer allocate(size_type count)
  { return (pointer)tiny_mempool::instance()
      .alloc(count * sizeof(value_type));
  }

  void deallocate(pointer ptr, size_type count)
  { return tiny_mempool::instance()
      .delloc(ptr, count * sizeof(value_type));
  }
};

template <typename _Tp1, typename _Tp2>
std::ostream& operator<<(std::ostream& os, 
  const std::pair<_Tp1, _Tp2>& pair) {
  return os << '{' << pair.first << ", " << pair.second << '}';
}

enum class rbcolor { red = false, blk = true };

struct rbnode_base {
  rbcolor m_color;
  rbnode_base* m_parent;
  rbnode_base* m_lchild;
  rbnode_base* m_rchild;
};

template <typename T>
struct rbnode : public rbnode_base {
  T m_data;

  T& data() const { return (T&)m_data; }

  rbcolor& color() const { return (rbcolor&)m_color; }

  rbnode*& parent() const { return (rbnode*&)m_parent; }

  rbnode*& lchild() const { return (rbnode*&)m_lchild; }

  rbnode*& rchild() const { return (rbnode*&)m_rchild; }

  rbnode*& brother() const {
    if (this == parent()->lchild())
      return parent()->rchild();
    else
      return parent()->lchild();
  }

  rbnode*& gparent() const { return parent()->parent(); }

  rbnode*& uncle() const {
    rbnode* father = parent();
    rbnode* grandfather = gparent();
    if (father == grandfather->lchild())
      return grandfather->rchild();
    else
      return grandfather->lchild();
  }

  bool isRed() const { return m_color == rbcolor::red; }

  bool isBlk() const { return m_color == rbcolor::blk; }

  void setRed() { m_color = rbcolor::red; }

  void setBlk() { m_color = rbcolor::blk; }

  static rbnode* prev(rbnode* x) {
    if (x->isRed() && x->gparent() == x) {
      x = x->rchild();
    } else if (x->lchild() != nullptr) {
      rbnode* y = x->lchild();
      while (y->rchild() != nullptr)
        y = y->rchild();
      x = y;
    } else {
      rbnode* y = x->parent();
      while (x == y->lchild()) {
        x = y;
        y = y->parent();
      }
      if (x->lchild() != y)
        x = y;
    }
    return x;
  }

  static rbnode* next(rbnode* x) {
    if (x->isRed() && x->gparent() == x) {
      x = x->lchild();
    } else if (x->rchild() != nullptr) {
      rbnode* y = x->rchild();
      while (y->lchild() != nullptr)
        y = y->lchild();
      x = y;
    } else {
      rbnode* y = x->parent();
      while (x == y->rchild()) {
        x = y;
        y = y->parent();
      }
      if (x->rchild() != y)
        x = y;
    }
    return x;
  }

};

template <typename T, typename Alloc=tiny_allocator<T>>
class rbtree {
 protected:
  template <typename _Tp, typename _Up>
  struct alloc_rebind {};

  template <template <typename, typename...> class _Template,
            typename _Up, typename _Tp, typename... _Types>
  struct alloc_rebind<_Template<_Tp, _Types...>, _Up>
  { using type = _Template<_Up, _Types...>; };

  using node = rbnode<T>;
  using allocator_type = typename alloc_rebind<Alloc, node>::type;

  struct rbtree_impl : public allocator_type
  { rbnode_base m_head;
    size_t m_node_count;
  };

  rbtree_impl m_impl;

 public:
  rbtree() { reset(); }

  rbtree(const rbtree& tree) { 
    if (tree.empty()) {
      reset();
      return;
    }
    m_impl.m_head.m_color = tree.m_impl.m_head.m_color;
    m_impl.m_node_count = tree.m_impl.m_node_count;
    root() = copyfrom(tree.root());
    root()->parent() = head();
    node* __min = root();
    node* __max = root();
    while (__min->lchild() != nullptr)
      __min = __min->lchild();
    while (__max->rchild() != nullptr)
      __max = __max->rchild();
    leftmost() = __min;
    rightmost() = __max;
  }

  rbtree(rbtree&& tree) {
    if (!tree.empty())
      movefrom(tree);
    else
      reset();
  }

  ~rbtree() { clear(); }

 public:
  bool insert(const T& value) 
  { return insert_unique(value); }

  bool insert_unique(const T& value) {
    node* pos = root();
    node* pos_parent = head();
    while (pos != nullptr) {  
      pos_parent = pos;
      if (less(value, pos->data()))
        pos = pos->lchild();
      else if (less(pos->data(), value))
        pos = pos->rchild();
      else
        return false;
    }
    __insert(pos_parent, value);
    ++m_impl.m_node_count;
    return true;
  }

  void insert_equal(const T& value) {
    node* pos = root();
    node* pos_parent = head();
    while (pos != nullptr) {  
      pos_parent = pos;
      if (less(value, pos->data()))
        pos = pos->lchild();
      else
        pos = pos->rchild();
    }
    __insert(pos_parent, value);
    ++m_impl.m_node_count;
  }

  size_t erase(const T& value) {
    node* lb = lower_bound(value);
    node* rb = upper_bound(value);
    size_t n = 0;
    while (lb != rb) {
      node* tmp = lb;
      lb = node::next(lb);
      __erase(tmp);
      ++n;
    }
    m_impl.m_node_count -= n;
    return n;
  }

  void clear() {
    if (!empty()) {
      __destroy(root());
      reset();
    }
  }

  size_t size() const
  { return m_impl.m_node_count; }

  bool empty() const
  { return size() == 0; }

  void disp() const {
    if (!empty())
      __disp(root());
    std::cout << std::endl;
  }

  template <typename U>
  node* lower_bound(const U& v) const {
    node* y = head();
    node* x = root();
    while (x != nullptr) {
      if (!less(x->m_data, v)) {
        y = x;
        x = x->lchild();
      } else {
        x = x->rchild();
      }
    }
    return y;
  }

  template <typename U>
  node* upper_bound(const U& v) const {
    node* y = head();
    node* x = root();
    while (x != nullptr) {
      if (less(v, x->m_data)) {
        y = x;
        x = x->lchild();
      } else {
        x = x->rchild();
      }
    }
    return y;    
  }

  rbtree& operator=(const rbtree& tree) {
    if (this == &tree) return *this;
    clear();
    new (this) rbtree(tree);
    return *this;
  }

  rbtree& operator=(rbtree&& tree) {
    if (this == &tree) return *this;
    clear();
    new (this) rbtree(std::move(tree));
    return *this;
  }

  friend std::ostream& operator<<(std::ostream& os, const rbtree& tree) {
    os << '{';
    auto begin = tree.leftmost();
    auto end = tree.head();
    if (begin != end) {
      os << begin->m_data;
      begin = node::next(begin);
      while (begin != end) {
        os << ", " << begin->m_data;
        begin = node::next(begin);
      }
    }
    return os << '}';
  }

 protected:
  node* head() const
  { return (node*)(&m_impl.m_head); }

  node*& root() const
  { return (node*&)m_impl.m_head.m_parent; }

  node*& leftmost() const
  { return (node*&)m_impl.m_head.m_lchild; }

  node*& rightmost() const
  { return (node*&)m_impl.m_head.m_rchild; }

  node* create_node(const T& x) { 
    node* tmp = m_impl.allocate(1);
    try {
       new (&tmp->m_data) T(x);
    } catch (...) {
      m_impl.deallocate(tmp, 1);
    }
    return tmp;
  }

  void destroy_node(node* p) {
    (&p->m_data)->~T();
    m_impl.deallocate(p, 1);
  }

  node* copyfrom(const node* rt) {
    if (rt == nullptr) return nullptr;
    node* p = create_node(rt->m_data);
    p->m_color = rt->m_color;
    node* lchild = copyfrom((const node*)rt->m_lchild);
    node* rchild = copyfrom((const node*)rt->m_rchild);
    p->lchild() = lchild;
    p->rchild() = rchild;
    if (lchild != nullptr)
      lchild->parent() = p;
    if (rchild != nullptr)
      rchild->parent() = p;
    return p;
  }

  void movefrom(rbtree& x) {
    head()->color() = x.head()->color();
    root() = x.root();
    leftmost() = x.leftmost();
    rightmost() = x.rightmost();
    root()->parent() = head();
    m_impl.m_node_count = x.m_impl.m_node_count;
    x.reset();
  }
  
  void reset() {
    m_impl.m_head.m_color = rbcolor::red;
    root() = nullptr;
    leftmost() = head();
    rightmost() = head();
    m_impl.m_node_count = 0;
  }

  template <typename _Tp>
  static bool less(const _Tp& x, const _Tp& y) 
  { return x < y; }

  template <typename _Tp1, typename _Tp2>
  static bool less(const std::pair<_Tp1, _Tp2>& x, 
            const std::pair<_Tp1, _Tp2>& y) {
    return x.first < y.first;
  }

  template <typename _Tp1, typename _Tp2>
  static bool less(const std::pair<_Tp1, _Tp2>& x, const _Tp1& y) {
    return x.first < y;
  }

  template <typename _Tp1, typename _Tp2>
  static bool less(const _Tp1& x, const std::pair<_Tp1, _Tp2>& y) {
    return x < y.first;
  }

 protected:
  void __leftRotate(node* x) {
    node* rchild = x->rchild();
    x->rchild() = rchild->lchild();
    if (rchild->lchild() != nullptr)
      rchild->lchild()->parent() = x;
    if (x == root())
      root() = rchild;
    else if (x == x->parent()->lchild())
      x->parent()->lchild() = rchild;
    else 
      x->parent()->rchild() = rchild;
    rchild->parent() = x->parent();
    rchild->lchild() = x;
    x->parent() = rchild;
  }

  void __rightRotate(node* x) {
    node* lchild = x->lchild();
    x->lchild() = lchild->rchild();
    if (lchild->rchild() != nullptr)
      lchild->rchild()->parent() = x;
    if (x == root())
      root() = lchild;
    else if (x == x->parent()->lchild())
      x->parent()->lchild() = lchild;
    else 
      x->parent()->rchild() = lchild;
    lchild->parent() = x->parent();
    lchild->rchild() = x;
    x->parent() = lchild;
  }

  void __disp(node* rt) const {
    if (rt->isRed())
      std::cout << "\033[1;31m" << rt->m_data;
    else 
      std::cout << "\033[1;34m" << rt->m_data;
    std::cout << "\033[0m";
    if (rt->lchild() || rt->rchild()) {
      std::cout << '(';
      if (rt->lchild())
        __disp(rt->lchild());
      if (rt->rchild()) {
        std::cout << ", ";
        __disp(rt->rchild());
      }
      std::cout << ')';
    }
  }

  void __destroy(node* pos) {
    while (pos != nullptr) {
      __destroy(pos->rchild());
      node* tmp = pos->lchild();
      destroy_node(pos);
      pos = tmp;
    }
  }

  void __insert(node* pos_parent, const T& value) {
    node* new_node = create_node(value);
    if (pos_parent == head()) {
      leftmost() = rightmost() = root() = new_node;
      new_node->setBlk();
      new_node->parent() = pos_parent;
      new_node->lchild() = nullptr;
      new_node->rchild() = nullptr;
      return;
    } else if (value < pos_parent->data()) {
      pos_parent->lchild() = new_node;
      if (pos_parent == leftmost())
        leftmost() = new_node;
    } else {
      pos_parent->rchild() = new_node;
      if (pos_parent == rightmost())
        rightmost() = new_node;
    }
    new_node->parent() = pos_parent;
    new_node->lchild() = nullptr;
    new_node->rchild() = nullptr;
    __preFixInsert(new_node);
  }

  void __erase(node* pos) {
    node* y = pos;
    node* x = nullptr;
    node* x_parent = nullptr;

    if (y->lchild() == nullptr) {
      x = y->rchild();
    } else {
      if (y->rchild() == nullptr) {
        x = y->lchild();
      } else {
        y = y->rchild();
        while (y->lchild() != nullptr)
          y = y->lchild();
        x = y->rchild();
      }
    }

    if (y != pos) {
      y->lchild() = pos->lchild();
      pos->lchild()->parent() = y;

      if (y != pos->rchild()) {
        x_parent = y->parent();
        if (x != nullptr)
          x->parent() = y->parent();
        y->parent()->lchild() = x;
        y->rchild() = pos->rchild();
        pos->rchild()->parent() = y;
      } else {
        x_parent = y;
      }

      if (pos == root())
        root() = y;
      else if (pos->parent()->lchild() == pos)
        pos->parent()->lchild() = y;
      else
        pos->parent()->rchild() = y;
      y->parent() = pos->parent();
      rbcolor tmp = y->color();
      y->color() = pos->color();
      pos->color() = tmp;
      y = pos;
    } else {
      x_parent = y->parent();
      if (x != nullptr)
        x->parent() = y->parent();
      
      if (pos == root()) {
        root() = x;
      } else {
        if (pos->parent()->lchild() == pos)
          pos->parent()->lchild() = x;
        else
          pos->parent()->rchild() = x;
      }

      if (pos == leftmost()) {
        if (pos->rchild() == nullptr)
          leftmost() = pos->parent();
        else
          leftmost() = x;
      }

      if (pos == rightmost()) {
        if (pos->lchild() == nullptr)
          rightmost() = pos->parent();
        else 
          rightmost() = x;
      }
    }

    if (x == root()) {
      if (x != nullptr)
        x->setBlk();
    } else if (y->isBlk())
      __preFixErase(x, x_parent);
    
    destroy_node(y);
  }

  void __preFixInsert(node* new_node) {
    new_node->setRed();
    if (new_node->parent()->isRed())
      __fixInsert(new_node);
  }

  void __preFixErase(node* successor, node* suc_parent) {
    if (successor != nullptr && successor->isRed()) {
      successor->setBlk();
      return;
    }
    __fixErase(successor, suc_parent);
  }

  void __fixInsert(node* new_node) {
    node* parent = new_node->parent();
    node* gparent = new_node->gparent();
    node* uncle = new_node->uncle();
    while (true) {
      if (parent == gparent->lchild()) {
        if (uncle == nullptr || uncle->isBlk()) {
          if (new_node == parent->rchild()) {
            __leftRotate(parent);
            new_node = parent;
            parent = new_node->parent();
          }
          parent->setBlk();
          gparent->setRed();
          __rightRotate(gparent);
          return;
        } else {
          parent->setBlk();
          uncle->setBlk();
                      
          if (gparent == root())
            return;
          gparent->setRed();
          if (gparent->parent()->isBlk())
            return;
          new_node = gparent;
          parent = new_node->parent();
          gparent = new_node->gparent();
          uncle = new_node->uncle();
          continue;
        }
      } else {
        if (uncle == nullptr || uncle->isBlk()) {
          if (new_node == parent->lchild()) {
            __rightRotate(parent);
            new_node = parent;
            parent = new_node->parent();
          }
          parent->setBlk();
          gparent->setRed();
          __leftRotate(gparent);
          return;
        } else {
          parent->setBlk();
          uncle->setBlk();
          if (gparent == root())
            return;
          gparent->setRed();
          if (gparent->parent()->isBlk())
            return;
          new_node = gparent;
          parent = new_node->parent();
          gparent = new_node->gparent();
          uncle = new_node->uncle();
          continue;
        }
      }
    }
    return; // never reach.
  }

  void __fixErase(node* successor, node* suc_parent) {
    node* parent = suc_parent;
    node* brother;
    while (true) {
      if (successor == parent->lchild()) {
        brother = parent->rchild();
        if (brother->isRed()) {
          parent->setRed();
          brother->setBlk();
          __leftRotate(parent);
          brother = parent->rchild();
        }
        if ((brother->lchild() == nullptr || brother->lchild()->isBlk()) 
         && (brother->rchild() == nullptr || brother->rchild()->isBlk())) {
          brother->setRed();
          if (parent->isRed()) {
            parent->setBlk();
            return;
          }
          if (parent == root())
            return;
          successor = parent;
          parent = successor->parent();
          brother = successor->brother();
          continue;
        } else {
          if (brother->rchild() == nullptr || brother->rchild()->isBlk()) {
            brother->setRed();
            brother->lchild()->setBlk();
            __rightRotate(brother);
            brother = parent->rchild();
          }
          brother->color() = parent->color();
          parent->setBlk();
          brother->rchild()->setBlk();
          __leftRotate(parent);
          return;
        }
      } else {
        brother = parent->lchild();
        if (brother->isRed()) {
          parent->setRed();
          brother->setBlk();
          __rightRotate(parent);
          brother = parent->lchild();
        }
        if ((brother->rchild() == nullptr || brother->rchild()->isBlk()) 
         && (brother->lchild() == nullptr || brother->lchild()->isBlk())) {
          brother->setRed();
          if (parent->isRed()) {
            parent->setBlk();
            return;
          }
          if (parent == root())
            return;
          successor = parent;
          parent = successor->parent();
          brother = successor->brother();
          continue;
        } else {
          if (brother->lchild() == nullptr || brother->lchild()->isBlk()) {
            brother->setRed();
            brother->rchild()->setBlk();
            __leftRotate(brother);
            brother = parent->lchild();
          }
          brother->color() = parent->color();
          parent->setBlk();
          brother->lchild()->setBlk();
          __rightRotate(parent);
          return;
        }
      }
    }
    return; // never reach.
  }
  
};

template <typename T>
class set {
 protected:
  rbtree<T> m_tree;

 public:
  set() = default;

  set(std::initializer_list<T> l) : m_tree() {
    for (auto&& x : l)
      m_tree.insert(x);
  }

  bool insert(const T& x)
  { return m_tree.insert_unique(x); }

  bool erase(const T& x)
  { return m_tree.erase(x) != 0; }

  bool empty() const 
  { return m_tree.empty(); }

  size_t size() const 
  { return m_tree.size(); }

  void clear()
  { m_tree.clear(); }

  void disp() const
  { m_tree.disp(); }

  friend std::ostream& operator<<(std::ostream& os, const set& x)
  { return os << x.m_tree; }

};

template <typename K, typename V>
class map {
 protected:
  rbtree<std::pair<K, V>> m_tree;

 public:
  map() = default;

  map(std::initializer_list<std::pair<K, V>> l) : m_tree() {
    for (auto&& x : l)
      m_tree.insert(x);
  }

  bool insert(const std::pair<K, V>& x)
  { return m_tree.insert_unique(x); }

  bool erase(const std::pair<K, V>& x)
  { return m_tree.erase(x) != 0; }

  bool empty() const 
  { return m_tree.empty(); }

  size_t size() const 
  { return m_tree.size(); }

  void clear()
  { m_tree.clear(); }

  void disp() const
  { m_tree.disp(); }

  friend std::ostream& operator<<(std::ostream& os, const map& x)
  { return os << x.m_tree; }

  V& operator[](const K& key) {
    m_tree.insert_unique(std::pair<K, V>(key, V()));
    rbnode<std::pair<K, V>>* lower = m_tree.lower_bound(key);
    return lower->m_data.second;
  }

  const V& operator[](const K& key) const {
    rbnode<std::pair<K, V>>* lower = m_tree.lower_bound(key);
    if (lower->m_data.first != key) {
      std::ostringstream msg;
      msg << "KeyError: " << key;
      throw std::logic_error(msg.str());
    }
    else
      return lower->m_data.second;
  }

};

测试

  • 开启多个线程,测试插入和删除操作。
#include <mutex>
#include <thread>
#include <vector>
#include <cassert>

#include "rbtree.h"

using namespace std;
using namespace chrono;

#define TEST_EPOCH 1000
#define TEST_COUNT 50
#define TEST_THREAD 20

mutex mtx;

static void test_rbtree() {
  rbtree<int> rbt;
  for (int i = 0; i < TEST_EPOCH; i++) {
    int arr[TEST_COUNT];
    for (int j = 0; j < TEST_COUNT; j++) {
      arr[j] = rand() % TEST_COUNT;
      rbt.insert(arr[j]);
    }
    mtx.lock();
    rbt.disp();
    mtx.unlock();
    for (int j = 0; j < TEST_COUNT; j++) rbt.erase(arr[j]);
    assert(rbt.size() == 0);
  }
}

int main(int argc, const char* argv[]) {
  srand(time(NULL));
  vector<thread> threads;
  for (int i = 0; i < TEST_THREAD; i++) {
    threads.emplace_back(test_rbtree);
    this_thread::sleep_for(milliseconds(1));
  }
  for (auto&& t : threads) t.join();
#ifdef __TINY_MEMPOOL__
  tiny_mempool::instance().report();
#endif
  return 0;
}
  • 附上一份测试截图:
> g++ -std=c++11 test.cc
> valgrind ./a.out

在这里插入图片描述

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值