public class AVLTree<T extends Comparable<? super T>> {
/**
* 根节点
*/
private AvlNode<T> root;
/**
* 插入
*
* @timestamp Mar 5, 2016 6:31:53 PM
* @param x
*/
public void insert(T x) {
root = insert(x, root);
}
/**
* 删除
*
* @timestamp Mar 5, 2016 3:44:42 PM
* @param x
*/
public void remove(T x) {
root = remove(x, root);
}
/**
* 打印
*
* @timestamp Mar 5, 2016 7:00:22 PM
*/
public void printTree() {
if (isEmpty())
System.out.println("Empty tree");
else
printTree(root);
}
/**
* 是否为空
*
* @timestamp Mar 5, 2016 4:21:26 PM
* @return
*/
public boolean isEmpty() {
return root == null;
}
/**
* 中序遍历
*
* @timestamp Mar 5, 2016 4:22:19 PM
* @param t
*/
private void printTree(AvlNode<T> t) {
if (t != null) {
printTree(t.left);
System.out.print(t.element + " ,");
printTree(t.right);
}
}
/**
* 删除
*
* @timestamp Mar 5, 2016 3:44:53 PM
* @param x
* @param t
* @return
*/
private AvlNode<T> remove(T x, AvlNode<T> t) {
if (t == null)
return t;
int compareResult = x.compareTo(t.element);
if (compareResult < 0)
t.left = remove(x, t.left);
else if (compareResult > 0)
t.right = remove(x, t.right);
else if (t.left != null && t.right != null) {
t.element = findMin(t.right).element;
t.right = remove(t.element, t.right);
} else
t = (t.left != null) ? t.left : t.right;
return t;
}
private AvlNode<T> findMin(AvlNode<T> t) {
if (t == null)
return null;
else if (t.left == null)
return t;
return findMin(t.left);
}
/**
* 实际插入
*
* @timestamp Mar 5, 2016 6:35:02 PM
* @param x
* @param t
* @return
*/
private AvlNode<T> insert(T x, AvlNode<T> t) {
if (t == null)
return new AvlNode<T>(x, null, null);
int compareResult = x.compareTo(t.element);
if (compareResult < 0) {// 左边
t.left = insert(x, t.left);
if (height(t.left) - height(t.right) == 2) {// 如果两遍深度相差大于1
if (x.compareTo(t.left.element) < 0)
t = rotateWithLeftChild(t);// 左旋
else
t = doubleWithLeftChild(t);// 双旋
}
} else if (compareResult > 0) {// 右边
t.right = insert(x, t.right);
if (height(t.right) - height(t.left) == 2) {// 如果两遍深度相差大于1
if (x.compareTo(t.right.element) > 0)
t = rotateWithRightChild(t);// 右旋
else
t = doubleWithRightChild(t);// 双旋
}
} else
; // 数据重复
t.height = Math.max(height(t.left), height(t.right)) + 1;// 重定义高度
return t;
}
/**
* 旋转右孩子
*
* @timestamp Mar 5, 2016 7:19:09 PM
* @param k1
* @return
*/
private AvlNode<T> rotateWithRightChild(AvlNode<T> k1) {
AvlNode<T> k2 = k1.right;// k1代表父节点,k2代表父节点的右孩子
k1.right = k2.left;// 孩子节点的左孩子 --> 父节点的右孩子
k2.left = k1;// 父节点 --> 孩子的左节点
k1.height = Math.max(height(k1.left), height(k1.right)) + 1;
k2.height = Math.max(height(k2.right), k1.height) + 1;// 重定义高度
return k2;
}
/**
* 旋转左孩子
*
* @timestamp Mar 5, 2016 7:18:55 PM
* @param k2
* @return
*/
private AvlNode<T> rotateWithLeftChild(AvlNode<T> k2) {
AvlNode<T> k1 = k2.left;// k2代表父节点,k1代表父节点的左孩子
k2.left = k1.right;// 孩子节点的右孩子 --> 父节点的左孩子
k1.right = k2;// 父节点 --> 孩子的右节点
k2.height = Math.max(height(k2.left), height(k2.right)) + 1;
k1.height = Math.max(height(k1.left), k2.height) + 1;// 重定义高度
return k1;
}
/**
* 双向旋转左孩子
*
* @timestamp Mar 5, 2016 7:19:42 PM
* @param k3
* @return
*/
private AvlNode<T> doubleWithLeftChild(AvlNode<T> k3) {
k3.left = rotateWithRightChild(k3.left);// 传入父节点的左孩子节点
return rotateWithLeftChild(k3);
}
/**
* 双向旋转右孩子
*
* @timestamp Mar 5, 2016 7:32:56 PM
* @param k1
* @return
*/
private AvlNode<T> doubleWithRightChild(AvlNode<T> k1) {
k1.right = rotateWithLeftChild(k1.right);// 传入父节点的右孩子节点
return rotateWithRightChild(k1);
}
/**
* 获取深度,没有返回-1
*
* @timestamp Mar 5, 2016 6:41:18 PM
* @param t
* @return
*/
private int height(AvlNode<T> t) {
return t == null ? -1 : t.height;
}
/**
* 节点
*
* @timestamp Mar 5, 2016 6:36:41 PM
* @author smallbug
* @param <E>
*/
private static class AvlNode<E> {
AvlNode(E theElement, AvlNode<E> lt, AvlNode<E> rt) {
element = theElement;
left = lt;
right = rt;
height = 0;
}
E element; // 数据
AvlNode<E> left; // 左孩子
AvlNode<E> right; // 右孩子
int height; // 深度
}
public static void main(String[] args) {
AVLTree<Integer> t = new AVLTree<>();
t.insert(3);
t.insert(2);
t.insert(1);
t.insert(4);
t.insert(5);
t.insert(6);
t.insert(7);
t.insert(10);
t.insert(9);
t.insert(8);
t.printTree();
System.out.println();
t.remove(8);
t.printTree();
}
}