// 主要包括添加、删除、遍历等功能
// (Main features: insertion, deletion and traversal for a plain BST and an AVL tree.)
#define _CRK_SECURE_NO_WARNINGS
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
using namespace std;
/// Returns the larger of two ints (b when they compare equal).
int max(int a, int b) {
    if (a > b) {
        return a;
    }
    return b;
}
/// Tree node used by both BST and AVLT: a key/value pair, the cached height
/// of the subtree rooted here, and raw links to the two children.
template<class K, class V>
class Node
{
public:
    K key;        // ordering key
    V value;      // payload stored for key
    int height;   // cached subtree height: 1 for a new leaf, 0 for a default-constructed node
    Node* left;   // child subtree holding smaller keys
    Node* right;  // child subtree holding larger keys

    /// Empty node: height 0, no children; key and value are default-initialized.
    Node() : height(0), left(NULL), right(NULL) {}

    /// Leaf node holding (k, v) with height 1 and no children.
    Node(K k, V v) : key(k), value(v), height(1), left(NULL), right(NULL) {}
};
template<class K, class V>
class BST {
public:
BST() {
root = NULL;
size = 0;
}
int Size() {
return size;
}
void add(K key, V value) {
root = add(root, key, value);
}
bool Contains(K key) {
return GetNode(root, key) != NULL;
}
Node<K, V>* GetNode(Node<K, V>* node, K key) {
if (node == NULL) {
return NULL;
}
if (key == node->key) {
return node;
}
else if (key < node->key) {
return GetNode(node->left, key);
}
else {
return GetNode(node->right, key);
}
}
V Get(K key) {
Node<K, V>* node = GetNode(root, key);
return node == NULL ? NULL : node->value;
}
void set(K key, V value) {
Node<K, V>* node = GetNode(root, key);
if (node == NULL) {
return;
}
node->value = value;
}
Node<K, V>* minimum(Node<K, V>* node) {
if (node->left == NULL) {
return node;
}
return minimum(node->left);
}
private:
//node为根
Node<K, V>* add(Node<K, V>* node, K e, V value) {
if (node == NULL) {
size++;
return new Node<K, V>(e, value);
}
if (e < node->key) {
node->left = add(node->left, e, value);
}
else if (e > node->key) {
node->right = add(node->right, e, value);
}
else {
node->value = value;
}
return node;
}
private:
Node<K, V>* root;
int size;
};
template<class K, class V>
class AVLT {
public:
AVLT() {
root = NULL;
size = 0;
}
int Size() {
return size;
}
int GetHeight(Node<K, V>* node) {
if (node == NULL) {
return 0;
}
return node->height;
}
int Height() {
return GetHeight(root);
}
int GetBalanceFactor(Node<K, V>* node) {
if (node == NULL) {
return 0;
}
return GetHeight(node->left) - GetHeight(node->right);
}
bool IsBST() {
vector<K> keys;
InOrder(root, keys);
for (int i = 1; i < (int)keys.size(); i++) {
if (keys[i - 1] > keys[i]) {
return false;
}
}
return true;
}
void InOrder(Node<K, V>* node, vector<K> keys) {
if (node == NULL) {
return;
}
InOrder(node->left, keys);
keys.push_back(node->key);
InOrder(node->right, keys);
}
bool IsBalanced() {
return IsBalanced(root);
}
bool IsBalanced(Node<K, V>* node) {
if (node == NULL) {
return true;
}
int bf = GetBalanceFactor(node);
if (abs(bf) > 1) {
return false;
}
return IsBalanced(node->left) && IsBalanced(node->right);
}
void add(K key, V value) {
root = add(root, key, value);
}
bool Contains(K key) {
return GetNode(root, key) != NULL;
}
Node<K, V>* GetNode(Node<K, V>* node, K key) {
if (node == NULL) {
return NULL;
}
if (key == node->key) {
return node;
}
else if (key < node->key) {
return GetNode(node->left, key);
}
else {
return GetNode(node->right, key);
}
}
V Get(K key) {
Node<K, V>* node = GetNode(root, key);
return node == NULL ? NULL : node->value;
}
void set(K key, V value) {
Node<K, V>* node = GetNode(root, key);
if (node == NULL) {
return;
}
node->value = value;
}
Node<K, V>* minimum(Node<K, V>* node) {
if (node->left == NULL) {
return node;
}
return minimum(node->left);
}
K minimumKey() {
return minimum(root)->key;
}
V Remove(K key) {
Node<K, V>* node = GetNode(root, key);
if (node != NULL) {
root = Remove(root, key);
return node->value;
}
return NULL;
}
void preOrder()
{
preOrder(root);
}
void inOrder()
{
inOrder(root);
}
void postOrder()
{
postOrder(root);
}
private:
void preOrder(Node<K, V>*node) const
{
if (node != NULL)
{
cout << node->key << "," << node->value << ";";
preOrder(node->left);
preOrder(node->right);
}
}
void inOrder(Node<K, V>*node) const
{
if (node != NULL)
{
inOrder(node->left);
cout << node->key << "," << node->value << ";";
inOrder(node->right);
}
}
void postOrder(Node<K, V>*node) const
{
if (node != NULL)
{
postOrder(node->left);
postOrder(node->right);
cout << node->key << "," << node->value << ";";
}
}
Node<K, V>* Remove(Node<K, V>* node, K key) {
if (node == NULL) {
return NULL;
}
Node<K, V>* ret;
if (key < node->key) {
node->left = Remove(node->left, key);
ret = node;
}
else if (key > node->key) {
node->right = Remove(node->right, key);
ret = node;
}
else {
if (node->left == NULL) {
Node<K, V>* rn = node->right;
node->right = NULL;
size--;
ret = rn;
}
else if (node->right == NULL) {
Node<K, V>* ln = node->left;
node->left = NULL;
size--;
ret = ln;
}
else
{
Node<K, V>* newroot = minimum(node->right);
newroot->right = Remove(node->right, newroot->key);
newroot->left = node->left;
node->left = node->right = NULL;
ret = newroot;
}
}
if (ret == NULL) {
return NULL;
}
ret->height = 1 + max(GetHeight(ret->left), GetHeight(ret->right));
int balancefactor = GetBalanceFactor(ret);
//左边较大,LL右旋转
if (balancefactor > 1 && GetBalanceFactor(ret->left) >= 0) {
return RightRotate(ret);
}
//左边较大,RR右旋转
if (balancefactor < -1 && GetBalanceFactor(ret->right) <= 0) {
return LeftRotate(ret);
}
//右边较大,LR左旋转,右旋转
if (balancefactor > 1 && GetBalanceFactor(ret->left) < 0) {
ret->left = LeftRotate(ret->left);
return RightRotate(ret);
}
//右边较大,RL右旋转,左旋转
if (balancefactor < -1 && GetBalanceFactor(ret->right) > 0) {
ret->right = RightRotate(ret->right);
return LeftRotate(ret);
}
balancefactor = GetBalanceFactor(node);
if (abs(balancefactor) > 1) {
cout << balancefactor << endl;
}
return ret;
}
//node为根
Node<K, V>* add(Node<K, V>* node, K e, V value) {
if (node == NULL) {
size++;
return new Node<K, V>(e, value);
}
else if (e < node->key) {
node->left = add(node->left, e, value);
}
else if (e > node->key) {
node->right = add(node->right, e, value);
}
else {
node->value = value;
}
node->height = 1 + max(GetHeight(node->left), GetHeight(node->right));
int balancefactor = GetBalanceFactor(node);
//左边较大,LL右旋转
if (balancefactor > 1 && GetBalanceFactor(node->left) >= 0) {
return RightRotate(node);
}
//左边较大,RR右旋转
if (balancefactor < -1 && GetBalanceFactor(node->right) <= 0) {
return LeftRotate(node);
}
//右边较大,LR左旋转,右旋转
if (balancefactor > 1 && GetBalanceFactor(node->left) < 0) {
node->left = LeftRotate(node->left);
return RightRotate(node);
}
//右边较大,RL右旋转,左旋转
if (balancefactor < -1 && GetBalanceFactor(node->right) > 0) {
node->right = RightRotate(node->right);
return LeftRotate(node);
}
balancefactor = GetBalanceFactor(node);
if (abs(balancefactor) > 1) {
cout << balancefactor << endl;
}
return node;
}
Node<K, V>* RightRotate(Node<K, V>* y) {
Node<K, V>* x = y->left;
//右旋转
y->left = x->right;
x->right = y;
//更新高度
y->height = 1 + max(GetHeight(y->left), GetHeight(y->right));
x->height = 1 + max(GetHeight(x->left), y->height);
return x;
}
Node<K, V>* LeftRotate(Node<K, V>* y) {
Node<K, V>* x = y->right;
//左旋转
y->right = x->left;
x->left = y;
//更新高度
y->height = 1 + max(GetHeight(y->left), GetHeight(y->right));
x->height = 1 + max(y->height, GetHeight(x->right));
return x;
}
private:
Node<K, V>* root;
int size;
};
int main(void)
{
ifstream infile("D:\\Desktop\\Clianxi\\C++\\trie\\testfile\\1.Harry Potter and the Sorcerer's Stone.txt");
if (!infile.is_open()) {
return -1;
}
vector<string> words;
BST<string, int> bst;
AVLT<string, int> avlt;
while (!infile.eof())
{
string line, tmp;
while (infile >> tmp) {
tmp.erase(0, tmp.find_first_not_of(" "));
words.push_back(tmp);
if (bst.Contains(tmp)) {
bst.set(tmp, bst.Get(tmp) + 1);
}
else
bst.add(tmp, 1);
if (avlt.Contains(tmp)) {
avlt.set(tmp, avlt.Get(tmp) + 1);
}
else
avlt.add(tmp, 1);
}
}
infile.close();
cout << "不同单词" << bst.Size() << endl;
cout << "___________" << endl;
cout << "不同单词" << avlt.Size() << endl;
cout << "___________" << endl;
cout << "isBST " << avlt.IsBST() << endl;
cout << "isBanlanced " << avlt.IsBalanced() << endl;
int arr[] = { 3,2,1,4,5,6,7,16,15,14,13,12,11,10,8,9 };
int i, len;
AVLT<int, int> *tree = new AVLT<int, int>();
cout << "依次添加:";
len = sizeof(arr) / sizeof(arr[0]);
for (i = 0; i < len; ++i)
{
cout << arr[i] << " ";
tree->add(arr[i],0);
}
cout << "\n前序遍历:";
tree->preOrder();
cout << "\n中序遍历:";
tree->inOrder();
cout << "\n后序遍历:";
tree->postOrder();
cout << "\n高度: " << tree->Height() << endl;
cout << "最小值: " << tree->minimumKey()<< endl;
for (i = 0; i < len; ++i)
{
cout << "\n删除节点:" << i;
tree->Remove(i);
cout << "\n高度: " << tree->Height() << endl;
cout << "中序遍历:";
tree->inOrder();
}
system("pause");
return 0;
}
// 主程序测试最后部分参考了 https://blog.csdn.net/codernim/article/details/54744619
// (The last part of the test driver in main() is adapted from the post above.)