二叉搜索树将链表插入的灵活性和数组的二分查找的高效性结合起来的符号表。
这棵树中的节点的左子树的值都小于该节点, 右子树都大于该节点
。
二叉搜索树和快速排序很像, 树中当前节点等价于快速排序中的基准数, 左子树都小于该节点, 右子树都大于该节点。
插入操作:
先判断在二叉搜索树中是否已经存在即将要插入的键, 若存在, 直接把值更新。
若不存在, 在当前节点中新建节点并插入。
void put(KEY k, VAL v){root = put(root, k, v);}
node<KEY, VAL> *put(node<KEY, VAL> *t, KEY k, VAL v){
if(t == NULL) return new node<KEY, VAL>(k, v, 1);
if(t->key < k) t->right = put(t->right, k, v);
else if(t->key > k) t->left = put(t->left, k, v);
else t->val = v;
t->N = size(t->left) + (t->right) + 1;
return t;
}
查找操作:
若当前的键小于要查找的键, 则去查找当前节点的右子树, 若当前键大于要查找的键, 则去查找当前节点的左子树, 若相等, 则返回该节点的值。
VAL get(KEY k){return get(root, k);}
VAL get(node<KEY, VAL> *t, KEY k){
if(t == NULL) return NULL;
if(t->key < k) return get(t->right, k);
else if(t->key > k) return get(t->left, k);
else return t->val;
}
最大键:
若当前节点存在右节点, 则递归,把右节点当做当前节点, 若不存在, 则返回当前节点。
KEY max(){return max(root);}
KEY max(node<KEY, VAL> *t){
if(t->right == NULL) return t->key;
else return max(t->right);
}
向上取整(返回小于等于当前节点的最大键):
若给定的键小于当前节点的键, 则小于等于当前节点的最大键肯定在当前节点的左子树中,
若给定的键大于当前节点的键, 则可能存在当前节点的右子树中, 若右子树存在此键就返回, 否则返回右子树的父节点。
KEY floor(KEY k){ return floor(root, k);}
KEY floor(node<KEY, VAL> *t, KEY k){
if(t == NULL) return NULL;
if(t->key == k) return t->key;
if(t->key > k) return floor(t->left, k);
node<KEY, VAL> *x = floor(t->right, k);
if(x == NULL) return t;
else return x;
}
返回排名为k的键(从0开始):
求当前节点的左子树中节点的个数, 若个数大于k,则继续递归到当前节点的左子树的节点上, 若个数小于k,则递归到右子树, 并把k改为 k-t-1, 减去左子树节点的个数和自己。
KEY select(int k){return select(root, k)->key;}
node<KEY, VAL> *select(node<KEY, VAL) *t, int k){
if(t == NULL) return NULL;
int cnt = size(t->left);
if(cnt > k) return select(t->left, k);
else if(cnt < k) return select(t->right, k-cnt-1);
else return t;
}
返回小于k的键的数量:
如果当前节点的键大于k,则去当前节点的左子树中继续寻找
若当前节点的键小于k, 则数量等于 1 + 左子树中节点的数目 + 右子树中键小于k的数量。
int rank(KEY k){return rank(root, k);}
int rank(node<KEY, VAL> *t, KEY k){
if(t == NULL) return 0;
if(k < t->key) return rank(t->left);
else if(t > t->key) return 1 + size(t->left) + rank(t->right);
else return size(t->left);
}
删除最小元素:
因为BST的性质, 最小元素肯定在根节点的左子树中, 所以一直遍历左子树, 直到没有左节点, 则当前节点就是最小元素, 把当前节点的右子树给当前节点的父节点, 并更新节点数N。
void delMin(){root = delMin(root);}
node<KEY, VAL> *delMin(node<KEY, VAL> *t){
if(t->left == NULL) return t->right;
t->left = delMin(t->left);
t->N = size(t->left) + size(t->right) + 1;
return t;
}
删除指定元素:
先找到该元素,
若该节点既有右节点又有左节点, 则把要删除节点的右子树的最小键替换要删除的元素。
若有左节点或者有右节点, 则把节点给要删除节点的父节点。
若没有左右节点, 则直接删除。
void dele(KEY k){root = dele(root, k);}
node<KEY, VAL> *dele(node<KEY, VAL> *t, KEY k){
if(t == NULL) return NULL;
if(t->key > k) return dele(t->left, k);
else if(t->key < k) return dele(t->right, k);
else{
if(t->right == NULL) return t->left;
if(t->left == NULL) return t->right;
node<KEY, VAL> *x = t;
t = min(t->right);
t->left = x->left;
t->right = delMin(t->right);
}
t->N = size(t->left) + size(t->right) + 1;
return t;
}
完整代码
#include<iostream>
using namespace std;
template<class KEY, class VAL>
class node{
public:
KEY key;
VAL val;
int N;
node<KEY, VAL> *left, *right;
node(){}
node(KEY k, VAL v, int n){key = k; val = v; N = n; left = right = NULL;}
};
template<class KEY, class VAL>
class BST{
node<KEY, VAL> *root;
int size(node<KEY, VAL> *t){
if(t == NULL) return 0;
return t->N;
}
node<KEY, VAL> *put(node<KEY, VAL> *t, KEY k, VAL v){
if(t == NULL) return new node<KEY, VAL>(k, v, 1);
if(t->key > k) t->left = put(t->left, k, v);
else if(t->key < k) t->right = put(t->right, k, v);
else t->val = v;
t->N = size(t->left) + size(t->right) + 1;
return t;
}
VAL get(node<KEY, VAL> *t, KEY k){
if(t == NULL) return NULL;
if(t->key > k) return get(t->left, k);
else if(t->key < k) return get(t->right, k);
else return t->val;
}
node<KEY, VAL> *max(node<KEY, VAL> *t){
if(t->right == NULL) return t;
return max(t->right);
}
node<KEY, VAL> *min(node<KEY, VAL> *t){
if(t->left == NULL) return t;
return min(t->left);
}
node<KEY, VAL> *floor(node<KEY, VAL> *t, KEY k){
if(t == NULL) return NULL;
if(t->key > k) return floor(t->left, k);
else if(t->key == k) return t;
else{
node<KEY, VAL> *x = floor(t->right, k);
if(x == NULL) return t;
return x;
}
}
node<KEY, VAL> *select(node<KEY, VAL> *t, int n){
if(t == NULL) return NULL;
int cnt = size(t->left);
if(cnt > n) return select(t->left, n);
else if(cnt < n) return select(t->right, n-cnt-1);
else return t;
}
int rank(node<KEY, VAL> *t, KEY k){
if(t == NULL) return 0;
if(t->key > k) return rank(t->left, k);
else if(t->key < k) return rank(t->right, k) + size(t->left) + 1;
else return size(t->left);
}
node<KEY, VAL> *delMin(node<KEY, VAL> *t){
if(t->left == NULL) return t->right;
t->left = delMin(t->left);
t->N = size(t->left) + size(t->right) + 1;
return t;
}
node<KEY, VAL> *dele(node<KEY, VAL> *t, KEY k){
if(t == NULL) return NULL;
if(t->key > k) return dele(t->left, k);
else if(t->key < k) return dele(t->right, k);
else{
if(t->right == NULL) return t->left;
if(t->left == NULL) return t->right;
node<KEY, VAL> *x = t;
t = min(t->right);
t->left = x->left;
t->right = delMin(t->right);
}
t->N = size(t->left) + size(t->right) + 1;
return t;
}
public:
BST(){root = new node<KEY, VAL>(); root = NULL;}
int size(){return size(root);}
void put(KEY k, VAL v){root = put(root, k, v);}
VAL get(KEY k){return get(root, k);}
KEY max(){return max(root)->key;}
KEY min(){return min(root)->key;}
KEY floor(KEY k){
node<KEY, VAL> *x = floor(root, k);
if(x == NULL) return 0;
return x->key;
}
KEY select(int n){return select(root, n)->key;}
int rank(KEY k){return rank(root, k);}
void delMin(){root = delMin(root);}
void dele(KEY k){root = dele(root, k);}
};
int main(){
BST<int, double> bst;
bst.put(3, 2);
bst.put(2, 3);
bst.put(1, 4);
bst.dele(2);
cout << bst.get(2);
return 0;
}