首先定义AVL树的节点类Node及其属性。
public class Node<T> {
T val; // 值
int height = 1; // 当前节点的高度
Node<T> left; // 左子树
Node<T> right; // 右子树
Node<T> parent; // 维护父节点,方便操作
public Node(T val) {
this.val = val;
}
public Node(T val, Node<T> parent) {
this.val = val;
this.parent = parent;
}
// 当前节点是否为左子树, 借助父节点的左子树是否等于当前节点进行判断
public boolean hasLeft() {
return parent != null && parent.left == this;
}
// 当前节点是否为右子树
public boolean hasRight() {
return parent != null && parent.right == this;
}
// 当前节点是否为平衡 | 高度 | <= 1
public boolean hasBalance() {
int lh = left != null ? left.height : 0;
int rh = right != null ? right.height : 0;
return Math.abs(lh - rh) < 2;
}
// 更新树的高度
public void updateHeight() {
int lh = left != null ? left.height : 0;
int rh = right != null ? right.height : 0;
height = Math.max(lh, rh) + 1;
}
// 找到当前节点最高的子节点
public Node<T> heightChild() {
int lh = left != null ? left.height : 0;
int rh = right != null ? right.height : 0;
return lh > rh ? left : right;
}
@Override
public String toString() {
return "Node{" +
"val=" + val +
", height=" + height +
", left=" + left +
", right=" + right +
'}';
}
}
定义AVL树类：
/**
 * A self-balancing binary search tree (AVL tree).
 *
 * <p>After every insertion or removal, cached node heights are refreshed along the path from the
 * modified position up to the root, and any node whose two subtrees differ in height by 2 is fixed
 * with a single (LL/RR) or double (LR/RL) rotation. Adding a value equal to one already present
 * leaves the tree unchanged.
 *
 * <p>Elements are ordered by the comparator given to the constructor or, when none is given, by
 * their natural ordering (they must then implement {@link Comparable}).
 *
 * @param <T> the element type
 */
public class AVLTree<T> {
    private Node<T> root; // root node, or null for an empty tree

    /** Comparator used to order elements, or null to use their natural ordering. */
    private final java.util.Comparator<? super T> comparator;

    /** Creates an empty tree that orders elements by their natural ordering. */
    public AVLTree() {
        this(null);
    }

    /**
     * Creates an empty tree that orders elements with {@code comparator}.
     *
     * @param comparator the ordering to use, or {@code null} for natural ordering
     */
    public AVLTree(java.util.Comparator<? super T> comparator) {
        this.comparator = comparator;
    }

    /**
     * Inserts {@code t} and rebalances. If an equal element is already present, nothing changes.
     *
     * @param t the value to insert
     */
    public void add(T t) {
        if (root == null) {
            root = new Node<>(t);
            return;
        }
        // Walk down to the node under which t belongs, remembering the last comparison result.
        Node<T> parent = root;
        Node<T> cur = root;
        int cmp = 0;
        while (cur != null) {
            parent = cur;
            cmp = compareTo(t, cur.val);
            if (cmp == 0) {
                return; // duplicate: keep the existing element
            }
            cur = cmp > 0 ? cur.right : cur.left;
        }
        Node<T> newNode = new Node<>(t, parent);
        if (cmp > 0) {
            parent.right = newNode;
        } else {
            parent.left = newNode;
        }
        afterAddAndRemove(newNode);
    }

    /**
     * Removes the element equal to {@code t}, if present, and rebalances.
     *
     * <p>A node with two children is not unlinked directly. Its value is overwritten with that of
     * its in-order predecessor, the largest value in its left subtree. The predecessor node has at
     * most one child, and it is the node that gets unlinked:
     *
     * <pre>
     *   remove 4:
     *         6                 6                 6
     *        / \               / \               / \
     *       4   7     ->      3   7     ->      3   7
     *      / \               / \               / \
     *     2   5             2   5             2   5
     *      \                 \
     *       3                 3  (predecessor node, unlinked next)
     * </pre>
     *
     * @param t the value to remove
     */
    public void remove(T t) {
        Node<T> node = find(t);
        if (node == null) {
            return;
        }
        if (node.left != null && node.right != null) {
            Node<T> predecessor = predecessor(node);
            node.val = predecessor.val;
            node = predecessor; // has no right child, so at most one child
        }
        // Splice the node out: its only child (or null for a leaf) takes its place. This also
        // handles removing the root, including a root that has a single child.
        Node<T> child = node.left != null ? node.left : node.right;
        replaceInParent(node, child);
        // node.parent still points at the former parent, which is where rebalancing starts.
        afterAddAndRemove(node);
        // Detach the removed node completely (GC).
        node.parent = null;
        node.left = null;
        node.right = null;
    }

    /** Returns the node holding a value equal to {@code t}, or null if there is none. */
    private Node<T> find(T t) {
        Node<T> cur = root;
        while (cur != null) {
            int cmp = compareTo(t, cur.val);
            if (cmp == 0) {
                return cur;
            }
            cur = cmp > 0 ? cur.right : cur.left;
        }
        return null;
    }

    /**
     * Returns the in-order predecessor of {@code node}, i.e. the rightmost node of its left
     * subtree, or null if {@code node} has no left subtree.
     */
    private Node<T> predecessor(Node<T> node) {
        Node<T> tar = node.left;
        if (tar == null) {
            return null;
        }
        while (tar.right != null) {
            tar = tar.right;
        }
        return tar;
    }

    /**
     * Puts {@code replacement} (possibly null) where {@code old} hangs. That is old's parent's
     * matching child slot, or {@link #root} when {@code old} has no parent. The replacement's
     * parent link is updated; {@code old.parent} itself is left untouched.
     */
    private void replaceInParent(Node<T> old, Node<T> replacement) {
        Node<T> parent = old.parent;
        if (replacement != null) {
            replacement.parent = parent;
        }
        if (parent == null) {
            root = replacement;
        } else if (parent.left == old) {
            parent.left = replacement;
        } else {
            parent.right = replacement;
        }
    }

    /**
     * Walks from {@code node}'s parent up to the root. Each balanced ancestor gets its cached
     * height refreshed; each unbalanced ancestor is rebalanced.
     */
    private void afterAddAndRemove(Node<T> node) {
        for (Node<T> cur = node.parent; cur != null; cur = cur.parent) {
            if (cur.hasBalance()) {
                cur.updateHeight();
            } else {
                // After rebalancing, cur has moved one level down, so cur.parent is the new
                // subtree root and the walk continues from there.
                rebalance(cur);
            }
        }
    }

    /**
     * Restores balance at {@code grand}, whose two subtrees differ in height by 2.
     *
     * <p>Let {@code parent} be grand's taller child. If parent's child on the same side is at least
     * as tall as its other child, one rotation at grand is enough (LL / RR). Otherwise parent is
     * rotated first (LR / RL). Equal heights must resolve to the single rotation: that case only
     * arises after a removal, and a double rotation there would leave the subtree unbalanced.
     */
    private void rebalance(Node<T> grand) {
        if (height(grand.left) > height(grand.right)) {
            Node<T> parent = grand.left;
            if (height(parent.left) < height(parent.right)) { // LR
                rotateLeft(parent);
            }
            rotateRight(grand); // LL, or second half of LR
        } else {
            Node<T> parent = grand.right;
            if (height(parent.right) < height(parent.left)) { // RL
                rotateRight(parent);
            }
            rotateLeft(grand); // RR, or second half of RL
        }
    }

    /**
     * Right rotation: grand's left child (the pivot) takes grand's place and grand becomes the
     * pivot's right child. The pivot's former right subtree moves to grand's left. Refreshes the
     * heights of grand and then the pivot.
     */
    private void rotateRight(Node<T> grand) {
        Node<T> pivot = grand.left;
        Node<T> moved = pivot.right;
        grand.left = moved;
        if (moved != null) {
            moved.parent = grand;
        }
        pivot.right = grand;
        replaceInParent(grand, pivot);
        grand.parent = pivot;
        grand.updateHeight();
        pivot.updateHeight();
    }

    /**
     * Left rotation: grand's right child (the pivot) takes grand's place and grand becomes the
     * pivot's left child. The pivot's former left subtree moves to grand's right. Refreshes the
     * heights of grand and then the pivot.
     */
    private void rotateLeft(Node<T> grand) {
        Node<T> pivot = grand.right;
        Node<T> moved = pivot.left;
        grand.right = moved;
        if (moved != null) {
            moved.parent = grand;
        }
        pivot.left = grand;
        replaceInParent(grand, pivot);
        grand.parent = pivot;
        grand.updateHeight();
        pivot.updateHeight();
    }

    /** Cached height of {@code node}'s subtree, with null counting as height 0. */
    private static int height(Node<?> node) {
        return node == null ? 0 : node.height;
    }

    /**
     * Compares two elements with the tree's comparator or, when none was supplied, with the
     * natural ordering of {@code t1}.
     *
     * @return a negative, zero or positive value as t1 is less than, equal to or greater than t2
     * @throws ClassCastException if no comparator was supplied and {@code t1} is not Comparable
     */
    @SuppressWarnings("unchecked")
    public int compareTo(T t1, T t2) {
        if (comparator != null) {
            return comparator.compare(t1, t2);
        }
        return ((Comparable<? super T>) t1).compareTo(t2);
    }

    @Override
    public String toString() {
        return "AVLTree{" +
                "node=" + root +
                '}';
    }
}
测试代码：
/**
 * Demo: inserts a fixed sequence of values, prints the tree, removes 1 and prints it again.
 */
public static void main(String[] args) {
    final int[] values = {5, 1, 3, 2, 6, 8, 4};
    AVLTree<Integer> tree = new AVLTree<>();
    for (int value : values) {
        tree.add(value);
    }
    System.out.println(tree); // state after all insertions
    tree.remove(1);
    System.out.println(tree); // state after removing 1
}
运行结果：
AVLTree{node=Node{val=3, height=4, left=Node{val=1, height=2, left=null, right=Node{val=2, height=1, left=null, right=null}}, right=Node{val=6, height=3, left=Node{val=5, height=2, left=Node{val=4, height=1, left=null, right=null}, right=null}, right=Node{val=8, height=1, left=null, right=null}}}}
AVLTree{node=Node{val=5, height=3, left=Node{val=3, height=2, left=Node{val=2, height=1, left=null, right=null}, right=Node{val=4, height=1, left=null, right=null}}, right=Node{val=6, height=2, left=null, right=Node{val=8, height=1, left=null, right=null}}}}
关注公众号“程序员秋田君”，领取更多Java课程资料。