AVL 树类。其中 Node 类封装了元素、左右子结点以及子树高度，用作树的结点：
/**
 * A self-balancing AVL binary search tree.
 *
 * <p>Each {@link Node} stores an element, its left and right children and the height of the
 * subtree rooted at it. A leaf has height 0 and an empty subtree has height -1. Inserting an
 * element that is already present (per {@code compareTo}) leaves the tree unchanged.
 *
 * @param <E> element type, ordered by its natural ordering
 */
public class AVLTree<E extends Comparable<E>> {
    private Node<E> root;

    /**
     * Inserts {@code element} unless an equal element is already present.
     *
     * @param element the element to insert
     * @return the root of the tree after insertion and rebalancing
     */
    public Node<E> add(E element) {
        return root = insert(element, root);
    }

    /**
     * Removes {@code element} if present; if it is absent the tree is left unchanged.
     *
     * @param element the element to remove
     * @return the root of the tree after removal and rebalancing, or {@code null} if the tree
     *     became (or already was) empty
     */
    public Node<E> delete(E element) {
        return root = remove(element, root);
    }

    /** Prints each element with its subtree height, in pre-order (node, left, right). */
    public void print() {
        print(root);
    }

    /** Recursively inserts into the subtree rooted at {@code node}; returns the new subtree root. */
    private Node<E> insert(E element, Node<E> node) {
        if (node == null) {
            return new Node<>(element);
        }
        int cmp = element.compareTo(node.element);
        if (cmp < 0) {
            node.left = insert(element, node.left);
        } else if (cmp > 0) {
            node.right = insert(element, node.right);
        }
        // cmp == 0: duplicate, nothing to insert.
        calcHeight(node);
        return balance(node);
    }

    /** Recursively removes from the subtree rooted at {@code node}; returns the new subtree root. */
    private Node<E> remove(E element, Node<E> node) {
        if (node == null) {
            // Element not found: nothing to remove.
            return null;
        }
        int cmp = element.compareTo(node.element);
        if (cmp < 0) {
            node.left = remove(element, node.left);
        } else if (cmp > 0) {
            node.right = remove(element, node.right);
        } else if (node.left == null || node.right == null) {
            // Zero or one child: splice this node out. The remaining child (or null) is already
            // a balanced subtree with a correct height, so it can be returned as-is.
            return node.left != null ? node.left : node.right;
        } else {
            // Two children: replace with the in-order successor, then delete that successor
            // from the right subtree.
            Node<E> rightMin = searchMin(node.right);
            node.element = rightMin.element;
            node.right = remove(rightMin.element, node.right);
        }
        calcHeight(node);
        return balance(node);
    }

    /** Pre-order traversal that prints "element , height = h" for every node. */
    private void print(Node<E> node) {
        if (node == null) {
            return;
        }
        System.out.println(node.element + " , height = " + node.height);
        print(node.left);
        print(node.right);
    }

    /** A tree node: the stored element, both children and the height of its subtree. */
    static class Node<E> {
        E element;
        Node<E> left;
        Node<E> right;
        /** Height of the subtree rooted here; 0 for a freshly created leaf. */
        int height;

        public Node(E element) {
            this.element = element;
        }
    }

    /** Returns the leftmost (smallest) node of a non-empty subtree. */
    private Node<E> searchMin(Node<E> node) {
        assert node != null;
        Node<E> current = node;
        while (current.left != null) {
            current = current.left;
        }
        return current;
    }

    /** Height of {@code node}, treating an empty subtree as -1. */
    private int height(Node<E> node) {
        return node == null ? -1 : node.height;
    }

    /** Recomputes {@code node.height} from its children's (already correct) heights. */
    private void calcHeight(Node<E> node) {
        node.height = Math.max(height(node.left), height(node.right)) + 1;
    }

    /** Single left rotation: the right child becomes the subtree root. */
    private Node<E> leftRotate(Node<E> node) {
        Node<E> newNode = node.right;
        node.right = newNode.left;
        newNode.left = node;
        // Old root is now lower, so its height must be recomputed first.
        calcHeight(node);
        calcHeight(newNode);
        return newNode;
    }

    /** Single right rotation: the left child becomes the subtree root. */
    private Node<E> rightRotate(Node<E> node) {
        Node<E> newNode = node.left;
        node.left = newNode.right;
        newNode.right = node;
        calcHeight(node);
        calcHeight(newNode);
        return newNode;
    }

    /** Double rotation for the left-right case. */
    private Node<E> leftAndRightRotate(Node<E> node) {
        node.left = leftRotate(node.left);
        return rightRotate(node);
    }

    /** Double rotation for the right-left case. */
    private Node<E> rightAndLeftRotate(Node<E> node) {
        node.right = rightRotate(node.right);
        return leftRotate(node);
    }

    /**
     * Restores the AVL property at {@code node}, assuming both subtrees are valid AVL trees whose
     * heights differ by at most 2. Returns the (possibly new) subtree root.
     */
    private Node<E> balance(Node<E> node) {
        if (height(node.left) - height(node.right) > 1) {
            // ">=" matters after deletions: the left child's subtrees may be equally tall, and
            // then only a single rotation yields a balanced result.
            if (height(node.left.left) >= height(node.left.right)) {
                return rightRotate(node);
            }
            return leftAndRightRotate(node);
        }
        if (height(node.right) - height(node.left) > 1) {
            if (height(node.right.right) >= height(node.right.left)) {
                return leftRotate(node);
            }
            return rightAndLeftRotate(node);
        }
        return node;
    }
}
测试代码:
/**
 * Demo driver: fills an AVL tree with a fixed sequence of integers, prints it, removes one
 * element and prints the tree again.
 */
public class Main {
    public static void main(String[] args) {
        AVLTree<Integer> avl = new AVLTree<>();
        // Insertion order matters: it determines which rotations happen along the way.
        int[] values = {3, 2, 1, 4, 5, 6, 7, 10, 9, 8};
        for (int value : values) {
            avl.add(value);
        }
        avl.print();

        System.out.println("=====================");

        avl.delete(4);
        avl.print();
    }
}
测试结果: