节点定义（TreeNode.java）：
/**
 * A node of a binary search tree: a key/value pair plus the links and bookkeeping fields the
 * owning tree maintains.
 *
 * @param <K> key type
 * @param <V> value type
 */
public class TreeNode<K, V> {
    /** The key. */
    public K key;
    /** The value stored under {@link #key}. */
    public V value;
    /** Left child, or {@code null}. */
    public TreeNode<K, V> left;
    /** Right child, or {@code null}. */
    public TreeNode<K, V> right;
    /** Parent node, or {@code null} for the root. */
    public TreeNode<K, V> parent;
    /** Whether this node is its parent's left child; not meaningful when {@link #parent} is null. */
    public boolean isLeftChild;
    /** Height of the subtree rooted at this node; kept up to date by the owning tree. */
    public int height;
    /**
     * Position number of this node; kept up to date by the owning tree (BinarySearchTree uses
     * heap-style numbers: root 1, left child 2n, right child 2n + 1).
     */
    public int num;

    /** Creates a node with all fields at their defaults. */
    public TreeNode() {
    }

    /**
     * Creates a node with the given key, value and links. {@link #isLeftChild}, {@link #height}
     * and {@link #num} keep their defaults and must be set by the caller.
     */
    public TreeNode(K key, V value, TreeNode<K, V> left, TreeNode<K, V> right, TreeNode<K, V> parent) {
        this.key = key;
        this.value = value;
        this.left = left;
        this.right = right;
        this.parent = parent;
    }

    /**
     * Returns whether this node hangs on its parent's left side.
     *
     * @return {@code true} only if this node has a parent and is flagged as its left child
     */
    public boolean isLeft() {
        return parent != null && isLeftChild;
    }

    /**
     * Returns whether this node hangs on its parent's right side. The root is neither a left nor
     * a right child, so this returns {@code false} for it.
     *
     * @return {@code true} only if this node has a parent and is not flagged as its left child
     */
    public boolean isRight() {
        return parent != null && !isLeftChild;
    }

    @Override
    public String toString() {
        return "Node[" + "key=" + key + ']';
    }
}
接口定义（IBinarySearchTree.java）：
import java.util.List;
import java.util.function.Consumer;
/**
 * A binary search tree mapping keys of type {@code K} to values of type {@code V}.
 *
 * @param <K> key type
 * @param <V> value type
 */
public interface IBinarySearchTree<K, V> {
    /**
     * Inserts a key/value pair.
     *
     * @param k the key
     * @param v the value
     * @return the node holding {@code k} after the insertion
     */
    TreeNode<K, V> insert(K k, V v);

    /**
     * Traverses the tree in order (ascending key order), passing every key to {@code con}.
     *
     * @param con callback invoked once per key
     */
    void inOrder(Consumer<K> con);

    /**
     * Looks up the value stored under a key.
     *
     * @param key the key to look up
     * @return the value stored under {@code key}
     */
    V lookupValue(K key);

    /**
     * Returns the smallest key.
     *
     * @return the smallest key in the tree
     */
    K min();

    /**
     * Returns the largest key.
     *
     * @return the largest key in the tree
     */
    K max();

    /**
     * Removes the node whose key equals {@code key}.
     *
     * @param key the key to remove
     */
    void remove(K key);

    /**
     * Returns the in-order successor of {@code x}: the smallest key greater than {@code x}.
     * <ol>
     *   <li>If x's node has a right subtree, the successor is the minimum of that subtree.</li>
     *   <li>Otherwise, starting at x's node, climb parent links until the current node is a
     *       left child; that node's parent is the successor.</li>
     * </ol>
     *
     * @param x the key whose successor is wanted
     * @return the successor key
     */
    K successor(K x);

    /**
     * Returns the in-order predecessor of {@code x}: the largest key smaller than {@code x}.
     * <ol>
     *   <li>If x's node has a left subtree, the predecessor is the maximum of that subtree.</li>
     *   <li>Otherwise, starting at x's node, climb parent links until the current node is a
     *       right child; that node's parent is the predecessor.</li>
     * </ol>
     *
     * @param x the key whose predecessor is wanted
     * @return the predecessor key
     */
    K predecessor(K x);

    /**
     * Returns whether the binary tree is balanced.
     *
     * @return {@code true} if the tree is balanced
     */
    boolean isBalance();

    /**
     * Returns the number of nodes.
     *
     * @return the number of nodes in the tree
     */
    int getSize();

    /**
     * Returns the height of the tree.
     *
     * @return the height of the tree
     */
    int getHeight();

    /**
     * Traverses the tree level by level.
     *
     * <p>NOTE(review): the method-level type parameters {@code <K, V>} shadow the interface's own
     * {@code K} and {@code V}, so callers pick them freely and implementations must cast.
     * Dropping them requires changing every implementation at the same time, because a generic
     * method cannot override a non-generic one.
     *
     * @return one list per depth, from the root level downward, each listing nodes left to right
     */
    <K, V>List<List<TreeNode<K, V>>> levelOrder();
}
实现类（BinarySearchTree.java）：
import java.util.*;
import java.util.function.Consumer;
/**
 * An unbalanced binary search tree.
 *
 * <p>Besides parent/child links, every node carries two bookkeeping fields, both kept up to date
 * by {@link #insert} and {@link #remove}:
 * <ul>
 *   <li>{@code height}: height of the node's subtree, where a leaf has height 0;</li>
 *   <li>{@code num}: heap-style position, where the root is 1, the left child of node n is 2n
 *       and the right child is 2n + 1. NOTE: this overflows {@code int} for nodes 31 or more
 *       levels below the root.</li>
 * </ul>
 *
 * <p>Keys are compared with {@code comparator} when it is set, otherwise by their natural
 * ordering. No synchronization is performed.
 *
 * @param <K> key type
 * @param <V> value type
 */
public class BinarySearchTree<K, V> implements IBinarySearchTree<K, V> {
    /** Number of nodes currently in the tree. */
    private int size = 0;
    /** Root node, or {@code null} when the tree is empty. */
    private TreeNode<K, V> root;
    /**
     * Key ordering used by {@link #compare}; when {@code null}, keys must be {@link Comparable}.
     * Nothing in this class assigns it.
     */
    private Comparator<? super K> comparator;

    /** Creates an empty tree. */
    public BinarySearchTree() {
    }

    /**
     * Creates a tree rooted at {@code root}. The nodes below {@code root} are adopted as they are
     * and are assumed to already satisfy the search-tree ordering. The size, node numbers and
     * heights are recomputed from that structure.
     *
     * @param root the root node, or {@code null} for an empty tree
     */
    public BinarySearchTree(TreeNode<K, V> root) {
        this.root = root;
        if (root != null) {
            size = renumber(root, 1);
            initHeights(root);
        }
    }

    /**
     * Inserts a key/value pair. If the key is already present, only its value is replaced.
     *
     * @param key   the key
     * @param value the value
     * @return the node now holding {@code key}
     * @throws ClassCastException if no comparator is set and {@code key} is not {@link Comparable}
     *                            (this includes a {@code null} key)
     */
    @Override
    public TreeNode<K, V> insert(K key, V value) {
        if (comparator == null && !(key instanceof Comparable)) {
            throw new ClassCastException("key does not implement Comparable: " + key);
        }
        TreeNode<K, V> parent = null;
        TreeNode<K, V> curr = root;
        int cmp = 0;
        while (curr != null) {
            cmp = compare(key, curr.key);
            if (cmp == 0) {
                // Existing key: overwrite the value in place, the structure is unchanged.
                curr.value = value;
                return curr;
            }
            parent = curr;
            curr = cmp < 0 ? curr.left : curr.right;
        }
        TreeNode<K, V> node = new TreeNode<>(key, value, null, null, parent);
        if (parent == null) {
            root = node;
            node.num = 1;
        } else if (cmp < 0) {
            parent.left = node;
            node.isLeftChild = true;
            node.num = 2 * parent.num;
        } else {
            parent.right = node;
            node.isLeftChild = false;
            node.num = 2 * parent.num + 1;
        }
        ++size;
        // The new leaf has height 0 (field default); only its ancestors can change.
        refreshHeights(parent);
        return node;
    }

    /**
     * Visits every key in ascending order.
     *
     * @param con callback invoked once per key
     */
    @Override
    public void inOrder(Consumer<K> con) {
        if (root != null) {
            inOrder(root, con);
        }
    }

    /** Recursive in-order walk of the subtree rooted at {@code p}. */
    private void inOrder(TreeNode<K, V> p, Consumer<K> con) {
        if (p != null) {
            inOrder(p.left, con);
            con.accept(p.key);
            inOrder(p.right, con);
        }
    }

    /**
     * Looks up the value stored under {@code key}.
     *
     * @param key the key to look up
     * @return the stored value, or {@code null} if the key is absent
     */
    @Override
    public V lookupValue(K key) {
        TreeNode<K, V> node = lookupNode(key);
        return node == null ? null : node.value;
    }

    /** Returns the node whose key equals {@code key}, or {@code null} if there is none. */
    private TreeNode<K, V> lookupNode(K key) {
        TreeNode<K, V> p = root;
        while (p != null) {
            int cmp = compare(key, p.key);
            if (cmp == 0) {
                return p;
            }
            p = cmp < 0 ? p.left : p.right;
        }
        return null;
    }

    /**
     * Returns the smallest key.
     *
     * @return the smallest key, or {@code null} if the tree is empty
     */
    @Override
    public K min() {
        return root == null ? null : minNode(root).key;
    }

    /** Returns the leftmost node of the non-null subtree {@code p}. */
    private TreeNode<K, V> minNode(TreeNode<K, V> p) {
        while (p.left != null) {
            p = p.left;
        }
        return p;
    }

    /**
     * Returns the largest key.
     *
     * @return the largest key, or {@code null} if the tree is empty
     */
    @Override
    public K max() {
        return root == null ? null : maxNode(root).key;
    }

    /** Returns the rightmost node of the non-null subtree {@code p}. */
    private TreeNode<K, V> maxNode(TreeNode<K, V> p) {
        while (p.right != null) {
            p = p.right;
        }
        return p;
    }

    /**
     * Removes the node whose key equals {@code key}. Does nothing if the key is absent.
     *
     * @param key the key to remove
     */
    @Override
    public void remove(K key) {
        TreeNode<K, V> node = lookupNode(key);
        if (node == null) {
            return; // nothing removed, so the size must not change
        }
        removeNode(node);
        --size;
    }

    /**
     * Unlinks {@code x} from the tree. A node with two children is not unlinked itself. It takes
     * over the key and value of its in-order successor, and the successor is unlinked instead.
     */
    private void removeNode(TreeNode<K, V> x) {
        if (x == null) {
            return;
        }
        if (x.left != null && x.right != null) {
            TreeNode<K, V> successor = minNode(x.right);
            // Carry the value along with the key, otherwise x keeps its old value.
            x.key = successor.key;
            x.value = successor.value;
            x = successor; // the minimum of a subtree never has a left child
        }
        // x now has at most one child; splice that child (or null) into x's place.
        TreeNode<K, V> parent = x.parent;
        transplant(x, x.left != null ? x.left : x.right);
        x.parent = null;
        x.left = null;
        x.right = null;
        refreshHeights(parent);
    }

    /**
     * Puts subtree {@code v} (possibly {@code null}) in the position currently held by {@code u}.
     * Updates the root or the parent's child pointer, and {@code v}'s parent link, left/right
     * flag and subtree numbering. {@code u}'s own links are left untouched.
     */
    private void transplant(TreeNode<K, V> u, TreeNode<K, V> v) {
        TreeNode<K, V> parent = u.parent;
        if (parent == null) {
            root = v;
        } else if (parent.left == u) {
            parent.left = v;
        } else {
            parent.right = v;
        }
        if (v != null) {
            v.parent = parent;
            v.isLeftChild = parent != null && parent.left == v;
            renumber(v, u.num);
        }
    }

    /**
     * Returns the smallest key greater than {@code x}. If x's node has a right subtree, that is
     * the subtree's minimum. Otherwise climb from x's node until the current node is a left
     * child; its parent is the successor.
     *
     * @param x the key whose successor is wanted
     * @return the successor key, or {@code null} if {@code x} is absent or is the largest key
     */
    @Override
    public K successor(K x) {
        TreeNode<K, V> xNode = lookupNode(x);
        if (xNode == null) {
            return null;
        }
        if (xNode.right != null) {
            return minNode(xNode.right).key;
        }
        TreeNode<K, V> yNode = xNode.parent;
        while (yNode != null && yNode.right == xNode) {
            xNode = yNode;
            yNode = yNode.parent;
        }
        return yNode == null ? null : yNode.key;
    }

    /**
     * Returns the largest key smaller than {@code x}. If x's node has a left subtree, that is the
     * subtree's maximum. Otherwise climb from x's node until the current node is a right child;
     * its parent is the predecessor.
     *
     * @param x the key whose predecessor is wanted
     * @return the predecessor key, or {@code null} if {@code x} is absent or is the smallest key
     */
    @Override
    public K predecessor(K x) {
        TreeNode<K, V> xNode = lookupNode(x);
        if (xNode == null) {
            return null;
        }
        if (xNode.left != null) {
            return maxNode(xNode.left).key;
        }
        TreeNode<K, V> yNode = xNode.parent;
        while (yNode != null && yNode.left == xNode) {
            xNode = yNode;
            yNode = yNode.parent;
        }
        return yNode == null ? null : yNode.key;
    }

    /**
     * Returns whether the tree is height-balanced: for every node, the heights of its left and
     * right subtrees differ by at most one. An empty tree is balanced.
     *
     * @return {@code true} if the tree is height-balanced
     */
    @Override
    public boolean isBalance() {
        return balancedHeight(root) >= 0;
    }

    /**
     * Returns the height of subtree {@code p}, counted in nodes (0 for {@code null}), or -1 as
     * soon as any node in it is unbalanced.
     */
    private int balancedHeight(TreeNode<K, V> p) {
        if (p == null) {
            return 0;
        }
        int l = balancedHeight(p.left);
        if (l < 0) {
            return -1;
        }
        int r = balancedHeight(p.right);
        if (r < 0 || Math.abs(l - r) > 1) {
            return -1;
        }
        return 1 + Math.max(l, r);
    }

    /**
     * Returns the number of nodes.
     *
     * @return the number of nodes in the tree
     */
    @Override
    public int getSize() {
        return size;
    }

    /**
     * Returns the height of the tree.
     *
     * @return the number of nodes on the longest root-to-leaf path ({@code 0} if empty)
     */
    @Override
    public int getHeight() {
        return getHeight(root);
    }

    /**
     * Returns the number of nodes on the longest downward path from {@code node}, or 0 for
     * {@code null}.
     */
    protected int getHeight(TreeNode node) {
        if (node == null) return 0;
        int l = getHeight(node.left);
        int r = getHeight(node.right);
        return 1 + Math.max(l, r);
    }

    /**
     * Traverses the tree level by level.
     *
     * <p>NOTE(review): the method-level {@code <K, V>} shadow the class's type parameters, as the
     * interface declares them. This forces the unchecked cast of {@code root} below.
     *
     * @return one list per depth, from the root downward, each listing nodes left to right.
     *         The result is empty for an empty tree.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <K, V> List<List<TreeNode<K, V>>> levelOrder() {
        List<List<TreeNode<K, V>>> levels = new ArrayList<>();
        if (root == null) {
            return levels;
        }
        Queue<TreeNode<K, V>> queue = new ArrayDeque<>();
        queue.add((TreeNode<K, V>) root);
        while (!queue.isEmpty()) {
            // Everything queued at this point belongs to the same level.
            int width = queue.size();
            List<TreeNode<K, V>> level = new ArrayList<>(width);
            for (int i = 0; i < width; i++) {
                TreeNode<K, V> node = queue.poll();
                level.add(node);
                if (node.left != null) {
                    queue.add(node.left);
                }
                if (node.right != null) {
                    queue.add(node.right);
                }
            }
            levels.add(level);
        }
        return levels;
    }

    /** Compares two keys with {@code comparator} if set, otherwise by natural ordering. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private int compare(K key1, K key2) {
        if (comparator != null) {
            return comparator.compare(key1, key2);
        }
        return ((Comparable) key1).compareTo(key2);
    }

    /**
     * Recomputes {@code height} for {@code node} and each of its ancestors. Assumes the heights
     * of their other children are already correct.
     */
    private void refreshHeights(TreeNode<K, V> node) {
        for (TreeNode<K, V> p = node; p != null; p = p.parent) {
            p.height = 1 + Math.max(heightOf(p.left), heightOf(p.right));
        }
    }

    /** Stored height of {@code node}, treating {@code null} as -1 so that a leaf gets 0. */
    private static int heightOf(TreeNode<?, ?> node) {
        return node == null ? -1 : node.height;
    }

    /** Recomputes {@code height} for every node of subtree {@code node} and returns its height. */
    private int initHeights(TreeNode<K, V> node) {
        if (node == null) {
            return -1;
        }
        node.height = 1 + Math.max(initHeights(node.left), initHeights(node.right));
        return node.height;
    }

    /**
     * Assigns heap-style numbers to the subtree rooted at {@code start}, giving {@code start}
     * the number {@code num}. Iterative, so deep subtrees do not overflow the stack.
     *
     * @return the number of nodes in the subtree
     */
    private int renumber(TreeNode<K, V> start, int num) {
        Deque<TreeNode<K, V>> pending = new ArrayDeque<>();
        start.num = num;
        pending.push(start);
        int count = 0;
        while (!pending.isEmpty()) {
            TreeNode<K, V> node = pending.pop();
            count++;
            if (node.left != null) {
                node.left.num = 2 * node.num;
                pending.push(node.left);
            }
            if (node.right != null) {
                node.right.num = 2 * node.num + 1;
                pending.push(node.right);
            }
        }
        return count;
    }
}
测试类（BinarySearchTreeTest.java，基于 JUnit 4 与 AssertJ）：
import org.junit.Before;
import org.junit.Test;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
/** Unit tests for {@code BinarySearchTree}, built on a fixed twelve-key fixture. */
public class BinarySearchTreeTest {
    BinarySearchTree<Integer, String> tree = new BinarySearchTree<Integer, String>();

    /**
     * Builds the fixture tree before every test. Insertion order matters; it yields:
     * <pre>
     *         4
     *       /   \
     *      1     10
     *       \   /  \
     *        3 7    14
     *       / / \     \
     *      2 5   9     16
     *                    \
     *                     20
     *                       \
     *                        25
     * </pre>
     * Only key 7 carries a non-null value.
     */
    @Before
    public void insert() {
        int[] keys = {4, 1, 10, 14, 7, 16, 9, 3, 5, 2, 20, 25};
        for (int key : keys) {
            tree.insert(key, key == 7 ? "嘿嘿嘿" : null);
        }
    }

    /** Collects the tree's keys in in-order sequence. */
    private List<Integer> inOrderKeys() {
        List<Integer> keys = new java.util.ArrayList<>();
        tree.inOrder(keys::add);
        return keys;
    }

    @Test
    public void sizeAndHeight() {
        assertThat(tree.getSize()).isEqualTo(12);
        // The longest path 4-10-14-16-20-25 has six nodes.
        assertThat(tree.getHeight()).isEqualTo(6);
    }

    @Test
    public void inOrder() {
        assertThat(inOrderKeys()).containsExactly(1, 2, 3, 4, 5, 7, 9, 10, 14, 16, 20, 25);
    }

    @Test
    public void lookupValue() {
        assertThat(tree.lookupValue(7)).isEqualTo("嘿嘿嘿");
        assertThat(tree.lookupValue(4)).isNull();
        assertThat(tree.lookupValue(8)).isNull();
    }

    @Test
    public void min() {
        assertThat(tree.min()).isEqualTo(1);
    }

    @Test
    public void max() {
        assertThat(tree.max()).isEqualTo(25);
    }

    @Test
    public void successorAndPredecessor() {
        assertThat(tree.successor(4)).isEqualTo(5);   // minimum of the right subtree
        assertThat(tree.successor(9)).isEqualTo(10);  // climb up from a right child
        assertThat(tree.successor(3)).isEqualTo(4);
        assertThat(tree.successor(25)).isNull();      // largest key
        assertThat(tree.predecessor(10)).isEqualTo(9); // maximum of the left subtree
        assertThat(tree.predecessor(5)).isEqualTo(4);  // climb up from a left child
        assertThat(tree.predecessor(1)).isNull();      // smallest key
    }

    @Test
    public void isBalance() {
        // The root's subtrees have heights 3 and 5, so the fixture is not balanced.
        assertThat(tree.isBalance()).isFalse();
    }

    @Test
    public void remove() {
        tree.remove(4);  // two children: replaced by its in-order successor 5
        assertThat(inOrderKeys()).containsExactly(1, 2, 3, 5, 7, 9, 10, 14, 16, 20, 25);
        tree.remove(9);  // leaf
        assertThat(inOrderKeys()).containsExactly(1, 2, 3, 5, 7, 10, 14, 16, 20, 25);
        tree.remove(16); // only a right child
        assertThat(inOrderKeys()).containsExactly(1, 2, 3, 5, 7, 10, 14, 20, 25);
        assertThat(tree.getSize()).isEqualTo(9);
        assertThat(tree.lookupValue(7)).isEqualTo("嘿嘿嘿");
    }

    @Test
    public void levelOrder() {
        List<List<TreeNode<Integer, String>>> lists = tree.levelOrder();
        List<List<Integer>> keys = new java.util.ArrayList<>();
        List<List<Integer>> nums = new java.util.ArrayList<>();
        for (List<TreeNode<Integer, String>> level : lists) {
            List<Integer> levelKeys = new java.util.ArrayList<>();
            List<Integer> levelNums = new java.util.ArrayList<>();
            for (TreeNode<Integer, String> node : level) {
                levelKeys.add(node.key);
                levelNums.add(node.num);
            }
            keys.add(levelKeys);
            nums.add(levelNums);
        }
        assertThat(keys).containsExactly(
                java.util.Arrays.asList(4),
                java.util.Arrays.asList(1, 10),
                java.util.Arrays.asList(3, 7, 14),
                java.util.Arrays.asList(2, 5, 9, 16),
                java.util.Arrays.asList(20),
                java.util.Arrays.asList(25));
        // Heap-style numbers: root 1, left child 2n, right child 2n + 1.
        assertThat(nums).containsExactly(
                java.util.Arrays.asList(1),
                java.util.Arrays.asList(2, 3),
                java.util.Arrays.asList(5, 6, 7),
                java.util.Arrays.asList(10, 12, 13, 15),
                java.util.Arrays.asList(31),
                java.util.Arrays.asList(63));
    }
}