RBTree.h 的实现（注意：下面的 set.h 与 map.h 通过 #include "RBTreeForSetMap.h" 引用此文件，保存时文件名需与之一致）:
#pragma once
#include<iostream>
#include<cassert>
#include<vector>
using namespace std;
// Colour of a red-black tree node: RED is the first enumerator (value 0),
// BLACK the second (value 1).
enum Colour
{
	RED,
	BLACK
};
/// One node of the red-black tree: the stored element plus links to its
/// parent and both children. A freshly constructed node is RED and has all
/// links null; the tree links and recolours it during insertion.
template<class T>
struct RBTreeNode {
	T data;                  // the stored element
	RBTreeNode<T>* parent;   // nullptr for the root
	RBTreeNode<T>* right;
	RBTreeNode<T>* left;
	Colour col;
	// explicit: a T must never silently convert into a detached tree node.
	explicit RBTreeNode(const T& data) : data(data), parent(nullptr), right(nullptr), left(nullptr), col(RED) {}
};
template<class T,class Ptr,class Ref>
struct RBTreeIterator {
typedef RBTreeNode<T> Node;
typedef RBTreeIterator<T, Ptr, Ref> Self;
Node* node;
RBTreeIterator(Node* node) :node(node) {}
Ref operator*() {
return node->data;
}
Ptr operator->() {
return &(node->data);
}
Self& operator++() {
if (node->right == nullptr)
{
Node* cur = node;
Node* parent = cur->parent;
while (parent && parent->right == cur)
{
cur = parent;
parent = cur->parent;
}
node = parent;
}
else {
Node* subLeft = node->right;
while (subLeft->left)
{
subLeft = subLeft->left;
}
node = subLeft;
}
return *this;
}
Self operator++(int) {
Self tem(*this);
++(*this);
return tem;
}
Self& operator--() {
if (node->left == nullptr)
{
Node* cur = node;
Node* parent = cur->parent;
while (parent && parent->left == cur)
{
cur = parent;
parent = cur->parent;
}
node = parent;
}
else {
Node* subRight = node->left;
while (subRight->right)
{
subRight = subRight->right;
}
node = subRight;
}
return *this;
}
Self operator--(int) {
Self tem(*this);
--(*this);
return tem;
}
bool operator==(const Self&s) const {
return s.node == node;
}
bool operator!=(const Self&s)const {
return s.node != node;
}
};
template<class K,class T,class KeyOfT>
class RBTree {
typedef RBTreeNode<T> Node;
public:
typedef RBTreeIterator<T, T*, T&>iterator;
typedef RBTreeIterator<T,const T*,const T&>const_iterator;
private:
Node* root = nullptr;
KeyOfT kot;
void RotateL(Node* parent) {
Node* subR = parent->right;
Node* subRL = subR->left;
parent->right = subRL;
if (subRL)
{
subRL->parent = parent;
}
subR->left = parent;
Node* pparent = parent->parent;
parent->parent = subR;
if (parent == root)
{
root = subR;
root->parent = nullptr;
}
else
{
if (parent == pparent->left)
{
pparent->left = subR;
}
else {
pparent->right = subR;
}
subR->parent = pparent;
}
}
void RotateR(Node* parent) {
Node* subL = parent->left;
Node* subLR = subL->right;
parent->left = subLR;
if (subLR)
{
subLR->parent = parent;
}
Node* pparent = parent->parent;
subL->right = parent;
parent->parent = subL;
if (parent == root)
{
root = subL;
root->parent = nullptr;
}
else
{
if (parent == pparent->left)
{
pparent->left = subL;
}
else {
pparent->right = subL;
}
subL->parent = pparent;
}
}
public:
iterator Begin() {
Node* subLeft = root;
while (subLeft && subLeft->left) {
subLeft = subLeft->left;
}
return iterator(subLeft);
}
iterator End() {
return iterator(nullptr);
}
const_iterator Begin() const{
Node* subLeft = root;
while (subLeft && subLeft->left) {
subLeft = subLeft->left;
}
return const_iterator(subLeft);
}
const_iterator End() const{
return const_iterator(nullptr);
}
pair<iterator,bool> Insert(const T& data) {
if (root == nullptr)
{
root = new Node(data);
root->col = BLACK;
return make_pair(iterator(root), true);
}
Node* cur = root;
Node* parent = nullptr;
while (cur)
{
if (kot(cur->data) < kot(data)) {
parent = cur;
cur = cur->right;
}
else if (kot(cur->data) > kot(data))
{
parent = cur;
cur = cur->left;
}
else
{
return make_pair(iterator(cur),true);
}
}
cur = new Node(data);
Node* newNode = cur;
cur->col = RED;
if (kot(data) < kot(parent->data))
{
parent->left = cur;
}
else {
parent->right = cur;
}
cur->parent = parent;
//父亲节点为黑色直接插入
//不能有连续的红色节点
while (parent && parent->col == RED)
{
Node* grandfather = parent->parent;
if (parent == grandfather->left)
{
Node* uncle = grandfather->right;
//uncle存在且为红色
if (uncle && uncle->col == RED)
{
parent->col = uncle->col = BLACK;
grandfather->col = RED;
cur = grandfather;
parent = cur->parent;
}
else //uncle不存在或存在且为黑色
{
if (cur == parent->left)
{
RotateR(grandfather);
parent->col = BLACK;
grandfather->col = RED;
}
else
{
RotateL(parent);
RotateR(grandfather);
cur->col = BLACK;
grandfather->col = RED;
}
//更新完局部子树后不会影响性质
break;
}
}
else {
Node* uncle = grandfather->left;
if (uncle && uncle->col == RED)
{
parent->col = uncle->col = BLACK;
grandfather->col = RED;
cur = grandfather;
parent = cur->parent;
}
else
{
if (cur == parent->right)
{
RotateL(grandfather);
parent->col = BLACK;
grandfather->col = RED;
}
else
{
RotateR(parent);
RotateL(grandfather);
cur->col = BLACK;
grandfather->col = RED;
}
break;
}
}
}
root->col = BLACK;
return make_pair(iterator(newNode),true);
}
iterator Find(const K& k) {
Node* cur = root;
KeyOfT kot;
while (cur)
{
if (kot(cur->data) > k) {
cur = cur->left;
}else if (kot(cur->data) < k)
{
cur = cur->right;
}
else
{
return iterator(cur);
}
}
return End();
}
};
set.h的实现:
#pragma once
#include"RBTreeForSetMap.h"
/// Ordered set of unique keys, backed by RBTree.
/// Both Iterator and const_Iterator are the tree's const_iterator, so keys
/// cannot be modified in place. Doing so would break the tree's ordering.
template<class K>
class set {
	// Key extractor: for a set, the stored element is the key itself.
	struct SetOfT {
		const K& operator()(const K& k) const {
			return k;
		}
	};
public:
	typedef typename RBTree<K, K, SetOfT>::const_iterator Iterator;
	typedef typename RBTree<K, K, SetOfT>::const_iterator const_Iterator;
private:
	RBTree<K, K, SetOfT> t;
public:
	/// Inserts key. Returns {iterator to the element with that key, the
	/// inserted flag reported by RBTree::Insert}.
	pair<Iterator,bool> insert(const K& key) {
		auto ret = t.Insert(key);
		// The tree hands back a mutable iterator; rewrap its node as a const one.
		return pair<Iterator,bool>(Iterator(ret.first.node), ret.second);
	}
	/// Returns an iterator to key, or end() when key is not present.
	Iterator find(const K& key) {
		return Iterator(t.Find(key).node);
	}
	Iterator begin() const {
		return t.Begin();
	}
	Iterator end() const {
		return t.End();
	}
};
void testSet1() {
set<int>s;
s.insert(4);
s.insert(5);
s.insert(1);
s.insert(3);
s.insert(2);
set<int>::Iterator it = s.begin();
while (it != s.end())
{
cout << *it << " ";
++it;
}
cout << endl;
}
map.h的实现:
#pragma once
#include"RBTreeForSetMap.h"
template<class K,class V>
class map {
struct MapOfT {
const K& operator()(const pair<K, V>& kv) {
return kv.first;
}
};
public:
typedef typename RBTree<K, pair<K, V>, MapOfT>::iterator Iterator;
typedef typename RBTree<K, pair<K, V>, MapOfT>::const_iterator const_Iterator;
private:
RBTree<K, pair<K, V>, MapOfT>t;
public:
pair<Iterator, bool> insert(const pair<K, V>& kv) {
return t.Insert(kv);
}
V& operator[](const K& k) {
auto ret = t.Insert({ k,V() });
return ret.first->second;
}
Iterator begin() {
return t.Begin();
}
Iterator end() {
return t.End();
}
const_Iterator begin() const {
return t.Begin();
}
const_Iterator end() const {
return t.End();
}
};
// Smoke test: insert five entries out of key order, then print "key:value"
// pairs in ascending key order on one line.
void testMap1() {
	map<string, int> m;
	const pair<string, int> entries[] = { { "111", 1 }, { "222", 2 }, { "333", 3 }, { "555", 5 }, { "444", 4 } };
	for (const auto& entry : entries)
	{
		m.insert(entry);
	}
	for (const auto& kv : m)
	{
		cout << kv.first << ":" << kv.second << " ";
	}
	cout << endl;
}
void testMap2()
{
string arr[] = { "苹果", "西瓜", "苹果", "西瓜", "苹果", "苹果", "西瓜", "苹果", "香蕉", "苹果", "香蕉" };
map<string, int> countMap;
for (auto& str : arr)
{
countMap[str]++;
}
for (const auto& kv : countMap)
{
cout << kv.first << ":" << kv.second << endl;
}
}