写完之后感觉红黑树也没有那么可怕了。
1.节点定义
enum { red = 0, black = 1 };
template <typename T>
struct RBTreeNode {
T val;
int color = red;
int repeated = 1;
RBTreeNode* left = nullptr;
RBTreeNode* right = nullptr;
RBTreeNode* parent = nullptr;
RBTreeNode() : val()
{}
RBTreeNode(const T& x) : val(x)
{}
};
template <typename T>
class RBTree {
private:
using Node = RBTreeNode<T>;
Node* head = nullptr;
int count = 0;
public:
RBTree()
{
head = new Node;
}
~RBTree()
{
if (head) {
Release(head);
head = nullptr;
}
}
RBTree(const RBTree&) = delete;
RBTree& operator=(const RBTree&) = delete;
RBTree(RBTree&&) = delete;
RBTree& operator=(RBTree&&) = delete;
.....
};
2.核心算法(接口)
Node* search(const T& x)
{
Node* cur = head->left;
while (cur) {
if (x < cur->val) {
cur = cur->left;
}
else if (x > cur->val) {
cur = cur->right;
}
else { // x == cur->val
return cur;
}
}
return nullptr;
}
void insert(const T& x)
{
Node* cur = InsertLeaf(x);
AdjustAfterInsert(cur);
}
void remove(const T& x)
{
bool removeleft = false;
Node* parent = RemoveLeaf(x, removeleft);
AdjustAfterRemove(parent, removeleft);
}
3.内部实现
3.1 插入
Node* InsertLeaf(const T& x)
{
count++;
Node* target = new Node(x);
if (!head->left) {
head->left = target;
target->parent = head;
target->color = black; // 根节点必须是黑色
return target;
}
Node* cur = head->left;
while (cur) {
if (x < cur->val) {
if (!cur->left) {
cur->left = target;
target->parent = cur;
return target;
}
cur = cur->left;
}
else if (x > cur->val) {
if (!cur->right) {
cur->right = target;
target->parent = cur;
return target;
}
cur = cur->right;
}
else { // x == cur->val
cur->repeated++;
break;
}
}
delete target;
return nullptr;
}
3.2.插入后调整(解决double red问题)
void AdjustAfterInsert(Node* cur)
{
if (!cur || cur == head->left || cur->parent->color == black) {
return;
}
Node* parent = cur->parent;
Node* grand = parent->parent;
assert(cur->color == red && parent != head->left && grand->color == black);
Node* uncle = (parent == grand->left) ? grand->right : grand->left;
if (uncle && uncle->color == red) { // parent和uncle都为红色,直接染黑,然后grand染红
parent->color = black;
uncle->color = black;
if (grand != head->left) { // 非根grand由黑变红后可能造成双红,可视为新插入的红节点,继续向根节点调整
grand->color = red;
AdjustAfterInsert(grand);
}
}
else { // parent为红色,uncle为黑色,可通过单旋转(g/p/u同向)或双旋转(g/p/u不同向)调整
if (parent == grand->left) {
if (cur == parent->right) { // g/p/u为LR(不同向)时,左旋一次p后变成同向
RotateLeft(parent);
parent = cur;
}
RotateRight(grand); // g/p/u为LL(同向)时,右旋一次g后即调整完成
}
else {
if (cur == parent->left) { // g/p/u为RL(不同向)时,右旋一次p后变成同向
RotateRight(parent);
parent = cur;
}
RotateLeft(grand); // g/p/u为RR(同向)时,左旋一次g后即调整完成
}
// 调整后新的grand节点继承原来grand位置节点的颜色(即黑色),新的子节点(g/c)染红
grand->color = red;
parent->color = black; // cur本来就为红色,可以不操作
}
}
3.3.删除
Node* RemoveLeaf(const T& x, bool& removeleft)
{
// 1.叶子结点,直接删除
// 2.仅含有一个子节点,删除后用该子节点(子树)占据自己原先的位置
// 3.含有两个子节点,将自身节点值修改为[左子树的最大节点]或[右子树的最小节点]的值,
// 并删除替代节点,删除操作同1(叶节点)或2(仅含有一个子树)的场景
Node* cur = search(x);
if (!cur) {
return nullptr;
}
count--;
if (cur->repeated > 1) {
cur->repeated--;
return nullptr;
}
Node* parent = cur->parent;
Node* tmp = nullptr;
if (cur->left && cur->right) {
tmp = cur->right;
while (tmp->left) {
tmp = tmp->left;
}
cur->val = tmp->val;
cur->repeated = tmp->repeated;
cur = tmp;
parent = cur->parent;
}
tmp = cur->left ? cur->left : cur->right;
if (tmp) {
tmp->parent = parent;
}
if (cur == parent->left) {
parent->left = tmp;
removeleft = true;
}
else {
parent->right = tmp;
removeleft = false;
}
int color = cur->color;
delete cur; // 这里删除的cur必然至少含有一个空孩子
if (color == red) { // 删除红节点(必然含有两个空孩子)
return nullptr;
}
if (tmp && tmp->color == red) { // 删除黑节点(含有一个空孩子和一个红色孩子),孩子染成黑色
tmp->color = black;
return nullptr;
}
return parent; // 删除黑节点(两个孩子一定均为空),导致路径欠一个黑色节点,后续调整
}
3.4.删除后调整(解决lack black问题)
void AdjustAfterRemove(Node *parent, bool removeleft)
{
if (!parent || parent == head) {
return;
}
Node *brother = removeleft ? parent->right : parent->left;
Node *cur = removeleft ? parent->left : parent->right;
assert((!cur || cur->color == black) && brother);
if (brother->color == red) { // 1.兄弟为红节点,旋转后转换成兄弟为黑节点的情况
assert(parent->color == black);
if (removeleft) {
RotateLeft(parent);
} else {
RotateRight(parent);
}
brother->color = black;
parent->color = red;
AdjustAfterRemove(parent, removeleft);
} else if (brother->color == black) { // 2.兄弟为黑节点(不可能为空)
Node *tmpl = brother->left;
Node *tmpr = brother->right;
int color = parent->color;
bool leftblack = (!tmpl || tmpl->color == black);
bool rightblack = (!tmpr || tmpr->color == black);
if (leftblack && rightblack) { // 2.1.兄弟无红色子节点,兄弟染红色
brother->color = red;
if (color == red) { // 父节点为红,直接染黑
parent->color = black;
} else { // 父节点为黑,路径仍然欠一个黑色节点,可视为黑色父节点被删除,继续向根节点调整
cur = parent;
parent = cur->parent;
AdjustAfterRemove(parent, cur == parent->left);
}
} else if (removeleft) {
if (!rightblack) { // 2.2-1.兄弟有一个或两个红色子节点(nephew),p/b/n同向,单旋转
RotateLeft(parent);
brother->color = color;
parent->color = black;
tmpr->color = black;
} else { // 2.3-1.兄弟有一个红色子节点(nephew),p/b/n不同向,双旋转
RotateRight(brother);
RotateLeft(parent);
tmpl->color = color;
parent->color = black;
brother->color = black;
}
} else {
if (!leftblack) { // 2.2-2.兄弟有一个或两个红色子节点(nephew),p/b/n同向,单旋转
RotateRight(parent);
brother->color = color;
parent->color = black;
tmpl->color = black;
} else { // 2.3-2.兄弟有一个红色子节点(nephew),p/b/n不同向,双旋转
RotateLeft(brother);
RotateRight(parent);
tmpr->color = color;
parent->color = black;
brother->color = black;
}
}
}
}
3.5.旋转
void RotateRight(Node* cur)
{
Node* subL = cur->left;
Node* subLR = subL->right;
Node* ancestor = cur->parent;
cur->left = subLR;
if (subLR) {
subLR->parent = cur;
}
subL->right = cur;
cur->parent = subL;
subL->parent = ancestor;
if (cur == ancestor->left) {
ancestor->left = subL;
}
else {
ancestor->right = subL;
}
}
void RotateLeft(Node* cur)
{
Node* subR = cur->right;
Node* subRL = subR->left;
Node* ancestor = cur->parent;
cur->right = subRL;
if (subRL) {
subRL->parent = cur;
}
subR->left = cur;
cur->parent = subR;
subR->parent = ancestor;
if (cur == ancestor->left) {
ancestor->left = subR;
}
else {
ancestor->right = subR;
}
}
4.验证
// https://leetcode.cn/problems/sort-an-array/description/
class Solution {
public:
vector<int> sortArray(vector<int>& nums)
{
RBTree<int> a;
srand(time(0));
int k = nums.size();
while (k) {
int x = rand() % k;
a.insert(nums[x]);
swap(nums[x], nums[k - 1]);
k--;
}
return a.traverse();
}
};
int main()
{
const int N = 10000;
vector<int> a(N);
srand(time(0));
for (int i = 0; i < N; i++) {
a[i] = rand();
}
vector<int> b(a), c(a);
for (int i = N; i >= 2; i--) {
int k = rand() % i;
swap(b[k], b[i - 1]);
k = rand() % i;
swap(c[k], c[i - 1]);
}
RBTree<int> f;
for (int i = 0; i < N; i++) {
f.insert(a[i]);
if (!f.check()) {
cout << i << " Insert Tree Error !" << endl;
exit(1);
}
}
assert(f.get_count() == N);
for (int i = 0; i < N; i++) {
auto found = f.search(b[i]);
if (!found || !f.check()) {
cout << i << " Search Tree Error !" << endl;
exit(2);
}
}
for (int i = 0; i < N; i++) {
f.remove(c[i]);
if (!f.check()) {
cout << i << " Remove Tree Error !" << endl;
exit(3);
}
}
assert(f.get_count() == 0);
cout << " Good Tree !" << endl;
return 0;
}