public class AVL 树 {
public Node root = null;
public void insert ( int key) {
if ( root == null) {
root = new Node ( key, null) ;
return ;
}
Node cur = root;
Node parent = null;
while ( cur!= null) {
if ( key== ( cur. key) ) {
throw new RuntimeException ( "key重复了 " + key) ;
}
else if ( key< cur. key) {
parent = cur;
cur = cur. left;
} else {
parent = cur;
cur = cur. right;
}
}
if ( key< parent. key) {
parent. left = new Node ( key, parent) ;
cur = parent. left;
} else {
parent. right = new Node ( key, parent) ;
cur = parent. right;
}
while ( true ) {
if ( cur == parent. left) {
parent. bf++ ;
} else {
parent. bf-- ;
}
if ( parent. bf== 0 ) {
break ;
} else if ( parent. bf== 2 ) {
if ( cur. bf== 1 ) {
fixll ( parent) ;
} else {
fixlr ( parent) ;
}
break ;
} else if ( parent. bf== - 2 ) {
if ( cur. bf== - 1 ) {
fixrr ( parent) ;
} else {
fixrl ( parent) ;
}
break ;
} else if ( parent == root) {
break ;
}
cur = parent;
parent = cur. parent;
}
}
private void leftRotate ( Node parent) {
Node cur = parent. right;
Node pp = parent. parent;
Node cc = cur. left;
if ( pp!= null) {
if ( parent== pp. left) {
pp. left = cur;
cur. parent = pp;
} else {
pp. right = cur;
cur. parent = pp;
}
} else {
root = cur;
cur. parent = null;
}
if ( cc!= null) {
parent. right = cc;
cc. parent = parent;
} else {
parent. right = null;
}
parent. parent = cur;
cur. left = parent;
}
private void rightRotate ( Node parent) {
Node pp = parent. parent;
Node cur = parent. left;
Node cc = cur. right;
if ( pp!= null) {
if ( parent== pp. left) {
pp. left = cur;
cur. parent = pp;
} else {
pp. right = cur;
cur. parent = pp;
}
} else {
root = cur;
cur. parent = null;
}
cur. right = parent;
parent. parent = cur;
if ( cc!= null) {
parent. left = cc;
cc. parent = parent;
} else {
parent. left= null;
}
}
private void fixrl ( Node parent) {
Node node = parent;
Node right = parent. right;
Node left = right. left;
rightRotate ( right) ;
leftRotate ( parent) ;
if ( left. bf== 1 ) {
node. bf = 0 ;
right. bf= - 1 ;
left. bf = 0 ;
} else if ( left. bf== - 1 ) {
node. bf = 1 ;
right. bf= 0 ;
left. bf = 0 ;
} else {
node. bf = 0 ;
left. bf = 0 ;
right. bf= 0 ;
}
}
private void fixlr ( Node parent) {
Node node = parent;
Node left = parent. left;
Node right = left. right;
leftRotate ( left) ;
rightRotate ( parent) ;
if ( right. bf== 1 ) {
node. bf = - 1 ;
left. bf = 0 ;
right. bf= 0 ;
} else if ( right. bf== - 1 ) {
node. bf = 0 ;
left. bf = 1 ;
right. bf= 0 ;
} else {
node. bf = 0 ;
left. bf = 0 ;
right. bf= 0 ;
}
}
private void fixrr ( Node parent) {
Node node = parent;
Node right = parent. right;
leftRotate ( parent) ;
node. bf = right. bf = 0 ;
}
private void fixll ( Node parent) {
Node node = parent;
Node left = parent. left;
rightRotate ( parent) ;
node. bf = left. bf = 0 ;
}
public boolean contains ( int key) {
Node cur = root;
while ( cur!= null) {
if ( key== ( cur. key) ) {
return true ;
}
else if ( key< cur. key) {
cur = cur. left;
} else {
cur = cur. right; }
}
return false ;
}
}
import java. util. ArrayList;
import java. util. Collections;
import java. util. List;
import java. util. Random;
/**
 * A single AVL tree node holding an {@code int} key.
 *
 * <p>{@code bf} is the node's balance factor; the tree code and the verifier in this
 * project treat it as height(left) - height(right). A freshly built node is a leaf
 * with {@code bf == 0} and no children.
 */
public class Node {
    int key;
    int bf;       // balance factor, 0 for a new leaf
    Node left;
    Node right;
    Node parent;  // null for the root

    /**
     * Creates a childless node.
     *
     * @param key    the key stored in this node
     * @param parent the node this one hangs from, or {@code null} for a root
     */
    public Node(int key, Node parent) {
        this.parent = parent;
        this.key = key;
        // bf, left and right deliberately keep Java's field defaults (0 / null).
    }
}
class Main {
public static void main ( String[ ] args) {
Random random = new Random ( ) ;
AVL树 tree = new AVL 树( ) ;
for ( int i = 0 ; i < 20 ; i++ ) {
try {
tree. insert ( random. nextInt ( 100000 ) ) ;
} catch ( RuntimeException e) {
System. out. println ( e. getMessage ( ) ) ;
}
}
verify ( tree) ;
}
public static void verify ( AVL树 tree) {
List< Integer> list = new ArrayList < > ( ) ;
中序遍历( tree. root, list) ;
List< Integer> list1= new ArrayList < > ( list) ;
Collections. sort ( list1) ;
if ( ! list1. equals ( list) ) {
throw new RuntimeException ( "该树中序遍历无序" ) ;
} else {
System. out. println ( "1.中序遍历有序" ) ;
}
try {
bf ( tree. root) ;
} catch ( RuntimeException e) {
System. out. println ( "平衡因子算错" ) ;
}
System. out. println ( "2.平衡因子正确" ) ;
}
public static void 中序遍历( Node root, List< Integer> list) {
if ( root== null) return ;
Node cur = root;
中序遍历( cur. left, list) ;
list. add ( cur. key) ;
中序遍历( cur. right, list) ;
}
public static void bf ( Node root) {
if ( root== null) { return ; }
int left = height ( root. left) ;
int right = height ( root. right) ;
if ( left- right!= root. bf|| root. bf> 1 || root. bf< - 1 ) {
throw new RuntimeException ( ) ;
}
bf ( root. left) ;
bf ( root. right) ;
}
public static int height ( Node root) {
if ( root== null) {
return 0 ;
}
int a = height ( root. right) ;
int b = height ( root. left) ;
return Math. max ( a, b) + 1 ;
}
}