public class Avl<T extends Comparable<? super T>> {
private AvlNode root;
private static final int ALLOWED_IMBALANCE = 1;
public Avl(){
this.root=null;
}
private static class AvlNode<T>{
T element;
AvlNode<T> left;
AvlNode<T> right;
int height;
AvlNode(T theElement){
this(theElement,null,null);
}
AvlNode(T theElement,AvlNode<T> lt,AvlNode<T> rt){
left=lt;
right=rt;
element=theElement;
height=0;
}
}
public void insert(T x){
root=insert(x,root);
}
private AvlNode<T> insert(T x,AvlNode<T> t){
if(t==null){
return new AvlNode<>(x,null,null);
}
int compareResult=x.compareTo(t.element);
if(compareResult<0){
t.left=insert(x,t.left);
}else if(compareResult>0){
t.right=insert(x,t.right);
}else{
}
return balance(t);
}
private AvlNode<T> balance(AvlNode<T> t){
if(t==null){
return t;
}
if((height(t.left)-height(t.right))>ALLOWED_IMBALANCE){
if(height(t.left.left)>=height(t.left.right)){
t=rotateWithLeftChild(t);
}else{
t=doubleWithLeftChild(t);
}
}else if((height(t.right)-height(t.left))>ALLOWED_IMBALANCE){
if(height(t.right.right)>=height(t.right.left)){
t=rotateWithRightChild(t);
}else{
t=doubleWithRightChild(t);
}
}
t.height = Math.max( height( t.left ), height( t.right ) ) + 1;
return t;
}
private AvlNode<T> rotateWithLeftChild(AvlNode<T> k2){
AvlNode<T> k1=k2.left;
k2.left=k1.right;
k1.right=k2;
k2.height=Math.max(height(k2.left),height(k2.right))+1;
k1.height=Math.max(height(k1.left),k2.height)+1;
return k1;
}
private AvlNode<T> rotateWithRightChild(AvlNode<T> k1){
AvlNode<T> k2=k1.right;
k1.right=k2.left;
k2.left=k1;
k1.height=Math.max(height(k1.left),height(k1.right))+1;
k2.height=Math.max(height(k2.right),k1.height)+1;
return k2;
}
private AvlNode<T> doubleWithLeftChild(AvlNode<T> k3){
k3.left=rotateWithRightChild(k3.left);
return rotateWithLeftChild(k3);
}
private AvlNode<T> doubleWithRightChild(AvlNode<T> k1){
k1.right=rotateWithLeftChild(k1.right);
return rotateWithRightChild(k1);
}
private int height(AvlNode<T> t){
return t==null?-1:t.height;
}
public void remove( T x ) {
root = remove( x, root );
}
private AvlNode<T> remove(T x,AvlNode<T> t){
if(t==null){
return null;
}
int compareResult = x.compareTo( t.element );
if(compareResult<0){
t.left=remove(x,t.left);
}else if(compareResult>0){
t.right=remove(x,t.right);
}else if(t.left!=null && t.right!=null){
t.element=findMin(t.right).element;
t.right=remove(t.element,t.right);
}else{
t=(t.left!=null)?t.left:t.right;
}
return balance(t);
}
private AvlNode<T> findMin(AvlNode<T> t){
if(t==null){
return null;
}
while (t.left!=null){
t=t.left;
}
return t;
}
private AvlNode<T> findMax(AvlNode<T> t){
if(t==null){
return null;
}
while (t.right!=null){
t=t.right;
}
return t;
}
public boolean contains(T x){
return contains(x,root);
}
private boolean contains(T x,AvlNode<T> t){
while (t!=null){
int compareResult=x.compareTo(t.element);
if(compareResult<0){
t=t.left;
}else if(compareResult>0){
t=t.right;
}else{
return true;
}
}
return false;
}
private void printTree(AvlNode<T> t){
if(t!=null){
printTree(t.left);
System.out.println(t.element);
printTree(t.right);
}
}
}
-
参考《数据结构与算法分析Java语言描述》