二叉查找树
代码实现
使用JAVA实现的二叉查找树,支持以下功能:
- 插入元素(容器可含有重复的数)
- 删除元素(对于相同的元素,只删除一个)
- 查询元素的排名(val的排名定义为比val小的数的个数+1)
- 根据排名查询元素
- 求某个元素的前驱(val的前驱为小于val,且最大的数)
- 求某个元素的后继(val的后继为大于val,且最小的数)
- 获取容器的尺寸(即容器中的元素数量)
- 判断容器是否包含某个元素
- 清空容器
package BalancedTree.BST;
import java.util.Stack;
public class BST implements Tree {
static class Node {
int val; // 结点中存储的值
Node leftChild, rightChild; // 左孩子和右孩子
int cnt; // 该结点所在子树中所有结点的数量
Node(int val) {
this.val = val;
leftChild = rightChild = null;
cnt = 1;
}
// 更新结点的信息,当该结点的子树发生变化后使用
void update() {
cnt = 1;
cnt += leftChild != null ? leftChild.cnt : 0;
cnt += rightChild != null ? rightChild.cnt : 0;
}
}
private Node root = null;
// 计算结点node在以node为树根的子树中的排名
private int nodeRank(Node node) {
if (node.leftChild != null) {
return node.leftChild.cnt + 1;
} else {
return 1;
}
}
// 插入结点
public void insert(int val) {
if (root == null) {
root = new Node(val);
} else {
Node cur = root;
while (true) {
++cur.cnt;
if (val < cur.val) {
if (cur.leftChild == null) {
cur.leftChild = new Node(val);
break;
} else {
cur = cur.leftChild;
}
} else {
if (cur.rightChild == null) {
cur.rightChild = new Node(val);
break;
} else {
cur = cur.rightChild;
}
}
}
}
}
// 删除值为val的结点
// 如果存在多个相同的数,只删除一个,返回true
// 如果不存在,返回false
public boolean remove(int val) {
if (root == null)
return false;
Node now = root;
Node father = null;
boolean isFatherLeft = true; // 记录now结点是father的左孩子还是右孩子
Stack<Node> toUpdate = new Stack<>(); // 记录在删除操作后需要更新的结点
while (now != null) {
toUpdate.push(now);
if (val < now.val) {
father = now;
isFatherLeft = true;
now = now.leftChild;
} else if (val > now.val) {
father = now;
isFatherLeft = false;
now = now.rightChild;
} else {
toUpdate.pop();
if (now.leftChild == null && now.rightChild == null) {
// 情况一:待删除结点没有左右孩子,直接删除即可
if (father == null) {
root = null;
} else if (isFatherLeft) {
father.leftChild = null;
} else {
father.rightChild = null;
}
} else if (now.leftChild != null && now.rightChild == null) {
// 情况二:待删除结点只有左孩子,则将该左孩子连接到父结点上即可
if (father == null) {
root = now.leftChild;
} else if (isFatherLeft) {
father.leftChild = now.leftChild;
} else {
father.rightChild = now.leftChild;
}
} else if (now.leftChild == null && now.rightChild != null) {
// 情况三:待删除结点只有右孩子,则将该右孩子连接到父结点上即可
if (father == null) {
root = now.rightChild;
} else if (isFatherLeft) {
father.leftChild = now.rightChild;
} else {
father.rightChild = now.rightChild;
}
} else {
// 情况四:待删除结点有左右孩子,则:
// 1. 找到该结点的右结点x
// 2. 一直向左,找到x最左边的左结点y
// 3. 将待删除结点的值改为y的值
// 4. 按情况三的思路删除y
toUpdate.push(now);
Node cur = now.rightChild;
Node _father = null;
while (cur.leftChild != null) {
toUpdate.push(cur);
_father = cur;
cur = cur.leftChild;
}
now.val = cur.val;
if (_father != null) {
_father.leftChild = cur.rightChild;
} else {
now.rightChild = cur.rightChild;
}
}
// 删除操作完毕后,更新祖先结点
while (!toUpdate.empty()) {
toUpdate.pop().update();
}
return true;
}
}
return false;
}
// 清空容器
public void clear() {
root = null;
}
// 获取结点总数
public int size() {
if (root == null) {
return 0;
} else {
return root.cnt;
}
}
// 获取val的排名
// val的排名定义为比val小的数的个数+1
public int rank(int val) {
Node cur = root;
int _rank = 1;
while (cur != null) {
if (val <= cur.val) {
cur = cur.leftChild;
} else {
_rank += nodeRank(cur);
cur = cur.rightChild;
}
}
return _rank;
}
// 获取排名为k的元素
// 返回值为-1表示元素不存在
public int kth(int rk) {
Node cur = root;
while (cur != null) {
int _rank = nodeRank(cur);
if (rk < _rank) {
cur = cur.leftChild;
} else if (rk > _rank) {
rk -= _rank;
cur = cur.rightChild;
} else {
return cur.val;
}
}
return -1;
}
// 检测容器是否包含某个值
public boolean contains(int val) {
Node cur = root;
while (cur != null) {
if (val < cur.val) {
cur = cur.leftChild;
} else if (val > cur.val) {
cur = cur.rightChild;
} else {
return true;
}
}
return false;
}
// 返回val的前驱
// val的前驱定义为小于val且最大的数
public int lower(int val) {
return kth(rank(val) - 1);
}
// 返回val的后继
// val的后继定义为大于val且最小的数
public int upper(int val) {
return kth(rank(val + 1));
}
}
功能验证
为了验证以上代码的正确性,笔者还编写了一个验题程序,能够产生随机数,对容器进行随机操作,并校验结果的正确性。
package BalancedTree.BST;
import java.util.*;
public class Test {
static final int ROUND = 200; // 测试轮数
static final int UPPER_BOUND = 200; // 往容器中添加的元素的最大值
static final int MAX_SIZE = 500; // 当容器尺寸超过阈值时,结束测试
static final int QUERY_TIME = 10; // 查询排名的次数
List<Integer> list;
Tree tree;
Random random;
Test(Tree tree) {
this.tree = tree;
list = new LinkedList<>();
random = new Random();
}
void list_insert(int val) {
ListIterator<Integer> iter = list.listIterator();
while (iter.hasNext()) {
int _val = iter.next();
if (val <= _val) {
iter.previous();
break;
}
}
iter.add(val);
}
int list_rank(int val) {
ListIterator<Integer> iter = list.listIterator();
int rank = 1;
while (iter.hasNext()) {
int _val = iter.next();
if (_val >= val) {
break;
}
++rank;
}
return rank;
}
void test_insert() {
int val = random.nextInt(UPPER_BOUND);
list_insert(val);
tree.insert(val);
test_size();
}
void test_size() {
if (tree.size() != list.size()) {
throw new Error("两个容器的size不一致!");
}
}
void test_remove() {
Integer val = random.nextInt(UPPER_BOUND);
if (list.remove(val) != tree.remove(val)) {
throw new Error("remove()的实现有问题!");
}
test_size();
}
void test_rank() {
for (int i = 0; i < QUERY_TIME; ++i) {
int val = random.nextInt(UPPER_BOUND);
if (list_rank(val) != tree.rank(val)) {
throw new Error("rank()的实现有问题!");
}
}
}
void test_kth() {
for (int i = 0; i < QUERY_TIME; ++i) {
int size = list.size();
int rank = random.nextInt(size * 3 + 1) - size;
int val;
if (rank >= 1 && rank <= list.size()) {
val = list.get(rank - 1);
} else {
val = -1;
}
if (tree.kth(rank) != val) {
throw new Error("kth()的实现有问题!");
}
}
}
void test_contains() {
for (int i = 0; i < QUERY_TIME; ++i) {
int val = random.nextInt(UPPER_BOUND);
if (tree.contains(val) != list.contains(val)) {
throw new Error("contains()的实现有问题!");
}
}
}
void test_lower() {
for (int i = 0; i < QUERY_TIME; ++i) {
int val = random.nextInt(UPPER_BOUND);
Iterator<Integer> iter = ((LinkedList<Integer>) list).descendingIterator();
int lower = -1;
while (iter.hasNext()) {
int _val = iter.next();
if (_val < val) {
lower = _val;
break;
}
}
if (tree.lower(val) != lower) {
throw new Error("lower()的实现有问题!");
}
}
}
void test_upper() {
for (int i = 0; i < QUERY_TIME; ++i) {
int val = random.nextInt(UPPER_BOUND);
int upper = -1;
for (int num : list) {
if (num > val) {
upper = num;
break;
}
}
if (tree.upper(val) != upper) {
throw new Error("upper()的实现有问题!");
}
}
}
void test() {
Random random = new Random();
for (int round = 0; round < ROUND; ++round) {
list.clear();
tree.clear();
if (list.size() != tree.size()) {
throw new Error("clear()的实现有问题!");
}
while (list.size() <= MAX_SIZE) {
int op = random.nextInt(8) + 1;
if (op == 1) {
test_insert();
} else if (op == 2) {
test_remove();
} else if (op == 3) {
test_rank();
} else if (op == 4) {
test_kth();
} else if (op == 5) {
test_upper();
} else if (op == 6) {
test_lower();
} else if (op == 7) {
test_size();
} else if (op == 8) {
test_contains();
} else {
throw new Error("op设置有误!");
}
}
test_size();
}
System.out.println("测试完毕,没有发现错误!");
}
public static void main(String[] args) {
new Test(new ScapegoatTree()).test();
}
}