SBT 与 AVL 类似，只是基于 size（子树大小）而非高度进行平衡，旋转的四种 case（LL、LR、RR、RL）是一样的。
/**
 * An ordered symbol table backed by a binary search tree that is balanced either as an AVL tree
 * (by subtree height), as a Size Balanced Tree (by subtree size), or not at all.
 *
 * <p>Both balancing strategies share the same four rotation cases (LL, LR, RR, RL); they differ
 * only in the condition that triggers a rotation. Keys are ordered by their natural ordering and
 * must not be {@code null}. This class performs no synchronization.
 *
 * @param <Key> key type
 * @param <Value> value type
 */
public class AVLOrSBTTreeMap<Key extends Comparable<Key>, Value> {

    /** Balance factor: the criterion used to decide when a subtree must be rotated. */
    public enum BF {
        /** AVL balancing: rotate when sibling subtree heights differ by more than one. */
        HEIGHT,
        /** SBT balancing: rotate when a grandchild subtree is larger than the opposite child. */
        SIZE,
        /** No balancing: a plain, possibly degenerate, binary search tree. */
        NONE
    }

    /** Balancing strategy; fixed at construction and inherited by {@link #copy()}. */
    private final BF bf;

    /** Root of the tree, or {@code null} when the map is empty. */
    private Node<Key, Value> root;

    /** Creates an empty map balanced as an AVL tree ({@link BF#HEIGHT}). */
    public AVLOrSBTTreeMap() {
        this(BF.HEIGHT);
    }

    /**
     * Creates an empty map that uses the given balancing strategy.
     *
     * @param balanceFactor balancing strategy
     * @throws NullPointerException if {@code balanceFactor} is null
     */
    public AVLOrSBTTreeMap(BF balanceFactor) {
        this.bf = java.util.Objects.requireNonNull(balanceFactor, "balanceFactor");
    }

    /**
     * Tree node caching the height and size of the subtree it roots. Static so that nodes do not
     * pin an enclosing map instance (important for {@link #copy()}).
     */
    private static final class Node<K, V> {
        K key;
        V value;
        Node<K, V> left, right;
        int ht, sz; // height and size of the subtree rooted at this node

        Node(K key, V value) {
            this.key = key;
            this.value = value;
            this.ht = 1;
            this.sz = 1;
        }
    }

    /** Height of the subtree rooted at {@code x}; 0 for an empty subtree. */
    private int height(Node<Key, Value> x) {
        return x == null ? 0 : x.ht;
    }

    /** Number of entries in the subtree rooted at {@code x}; 0 for an empty subtree. */
    private int size(Node<Key, Value> x) {
        return x == null ? 0 : x.sz;
    }

    /** Recomputes {@code x}'s cached height and size from its (already up-to-date) children. */
    private void updateHeightAndSize(Node<Key, Value> x) {
        x.sz = size(x.left) + size(x.right) + 1;
        x.ht = Math.max(height(x.left), height(x.right)) + 1;
    }

    /** Rotates {@code x} right; requires a non-null left child. Returns the new subtree root. */
    private Node<Key, Value> rotateRight(Node<Key, Value> x) {
        Node<Key, Value> l = x.left;
        x.left = l.right;
        updateHeightAndSize(x); // x is now a child of l, so update it first
        l.right = x;
        updateHeightAndSize(l);
        return l;
    }

    /** Rotates {@code x} left; requires a non-null right child. Returns the new subtree root. */
    private Node<Key, Value> rotateLeft(Node<Key, Value> x) {
        Node<Key, Value> r = x.right;
        x.right = r.left;
        updateHeightAndSize(x); // x is now a child of r, so update it first
        r.left = x;
        updateHeightAndSize(r);
        return r;
    }

    /** Height of the left subtree minus height of the right subtree of {@code x}. */
    private int heightDiff(Node<Key, Value> x) {
        return height(x.left) - height(x.right);
    }

    /**
     * Rebalances the subtree rooted at {@code x} (whose children are already up to date) with at
     * most one single or double rotation, then refreshes its cached height and size.
     *
     * <p>NOTE(review): unlike the classic SBT maintain, the SIZE strategy here does not recursively
     * re-maintain the subtrees touched by a rotation, so the SBT size invariant may only hold
     * approximately. Map semantics are unaffected.
     *
     * @return the new root of the subtree
     */
    private Node<Key, Value> maintain(Node<Key, Value> x) {
        if ((bf == BF.HEIGHT && heightDiff(x) > 1 && heightDiff(x.left) >= 0)
                || (bf == BF.SIZE && x.left != null && size(x.left.left) > size(x.right))) {
            // LL: the left child's left side is too heavy; one right rotation fixes it.
            x = rotateRight(x);
        } else if ((bf == BF.HEIGHT && heightDiff(x) > 1 && heightDiff(x.left) < 0)
                || (bf == BF.SIZE && x.left != null && size(x.left.right) > size(x.right))) {
            // LR: reduce to LL by rotating the left child left, then rotate right.
            x.left = rotateLeft(x.left);
            x = rotateRight(x);
        } else if ((bf == BF.HEIGHT && heightDiff(x) < -1 && heightDiff(x.right) <= 0)
                || (bf == BF.SIZE && x.right != null && size(x.right.right) > size(x.left))) {
            // RR: mirror of LL.
            x = rotateLeft(x);
        } else if ((bf == BF.HEIGHT && heightDiff(x) < -1 && heightDiff(x.right) > 0)
                || (bf == BF.SIZE && x.right != null && size(x.right.left) > size(x.left))) {
            // RL: mirror of LR.
            x.right = rotateRight(x.right);
            x = rotateLeft(x);
        }
        updateHeightAndSize(x);
        return x;
    }

    /** Inserts or replaces {@code key} in the subtree rooted at {@code x}; returns the new root. */
    private Node<Key, Value> put(Node<Key, Value> x, Key key, Value value) {
        if (x == null) {
            return new Node<>(key, value);
        }
        int cmp = key.compareTo(x.key);
        if (cmp < 0) {
            x.left = put(x.left, key, value);
        } else if (cmp > 0) {
            x.right = put(x.right, key, value);
        } else {
            x.value = value;
        }
        return maintain(x);
    }

    /**
     * Associates {@code value} with {@code key}, replacing any previous value.
     *
     * @throws NullPointerException if {@code key} is null
     */
    public void put(Key key, Value value) {
        // Checked eagerly: inserting into an empty tree would otherwise never call compareTo.
        java.util.Objects.requireNonNull(key, "key");
        root = put(root, key, value);
    }

    /** Leftmost (smallest-key) node of the non-empty subtree rooted at {@code x}. */
    private Node<Key, Value> min(Node<Key, Value> x) {
        while (x.left != null) {
            x = x.left;
        }
        return x;
    }

    /** Unlinks the smallest node of the non-empty subtree rooted at {@code x}; returns new root. */
    private Node<Key, Value> removeMin(Node<Key, Value> x) {
        if (x.left == null) {
            return x.right;
        }
        x.left = removeMin(x.left);
        return maintain(x);
    }

    /** Removes {@code key} from the subtree rooted at {@code x} if present; returns the new root. */
    private Node<Key, Value> remove(Node<Key, Value> x, Key key) {
        if (x == null) {
            return null;
        }
        int cmp = key.compareTo(x.key);
        if (cmp < 0) {
            x.left = remove(x.left, key);
        } else if (cmp > 0) {
            x.right = remove(x.right, key);
        } else {
            if (x.left == null) {
                return x.right;
            }
            if (x.right == null) {
                return x.left;
            }
            // Two children: take over the in-order successor's entry, then unlink the successor.
            Node<Key, Value> successor = min(x.right);
            x.key = successor.key;
            x.value = successor.value;
            x.right = removeMin(x.right);
        }
        return maintain(x);
    }

    /**
     * Removes {@code key} and its value if present; otherwise does nothing.
     *
     * @throws NullPointerException if {@code key} is null
     */
    public void remove(Key key) {
        java.util.Objects.requireNonNull(key, "key");
        root = remove(root, key);
    }

    /** Returns the number of entries in this map. */
    public int size() {
        return size(root);
    }

    /** Number of keys in the subtree rooted at {@code x} that are strictly less than {@code key}. */
    private int rank(Node<Key, Value> x, Key key) {
        if (x == null) {
            return 0;
        }
        int cmp = key.compareTo(x.key);
        if (cmp == 0) {
            return size(x.left);
        }
        if (cmp < 0) {
            return rank(x.left, key);
        }
        return size(x.left) + 1 + rank(x.right, key);
    }

    /**
     * Returns the number of keys in this map strictly less than {@code key} ({@code key} need not
     * be present).
     *
     * @throws NullPointerException if {@code key} is null
     */
    public int rank(Key key) {
        java.util.Objects.requireNonNull(key, "key");
        return rank(root, key);
    }

    /** Node holding the key of 0-based {@code rank} within the subtree, or null if out of range. */
    private Node<Key, Value> select(Node<Key, Value> x, int rank) {
        if (x == null) {
            return null;
        }
        int leftSize = size(x.left);
        if (rank == leftSize) {
            return x;
        }
        if (rank < leftSize) {
            return select(x.left, rank);
        }
        return select(x.right, rank - leftSize - 1);
    }

    /**
     * Returns the key with the given 0-based rank, or {@code null} if {@code rank} is outside
     * {@code [0, size())}.
     */
    public Key select(int rank) {
        Node<Key, Value> x = select(root, rank);
        return x == null ? null : x.key;
    }

    /** Node with the largest key {@code <= key} in the subtree, or null if none. */
    private Node<Key, Value> floor(Node<Key, Value> x, Key key) {
        if (x == null) {
            return null;
        }
        int cmp = key.compareTo(x.key);
        if (cmp == 0) {
            return x;
        }
        if (cmp < 0) {
            return floor(x.left, key);
        }
        Node<Key, Value> f = floor(x.right, key);
        return f == null ? x : f;
    }

    /**
     * Returns the largest key less than or equal to {@code key}, or {@code null} if there is none.
     *
     * @throws NullPointerException if {@code key} is null
     */
    public Key floor(Key key) {
        java.util.Objects.requireNonNull(key, "key");
        Node<Key, Value> f = floor(root, key);
        return f == null ? null : f.key;
    }

    /** Node with the smallest key {@code >= key} in the subtree, or null if none. */
    private Node<Key, Value> ceiling(Node<Key, Value> x, Key key) {
        if (x == null) {
            return null;
        }
        int cmp = key.compareTo(x.key);
        if (cmp == 0) {
            return x;
        }
        if (cmp > 0) {
            return ceiling(x.right, key);
        }
        Node<Key, Value> c = ceiling(x.left, key);
        return c == null ? x : c;
    }

    /**
     * Returns the smallest key greater than or equal to {@code key}, or {@code null} if there is
     * none.
     *
     * @throws NullPointerException if {@code key} is null
     */
    public Key ceiling(Key key) {
        java.util.Objects.requireNonNull(key, "key");
        Node<Key, Value> c = ceiling(root, key);
        return c == null ? null : c.key;
    }

    /** Value stored under {@code key} in the subtree, or null if absent. */
    private Value get(Node<Key, Value> x, Key key) {
        while (x != null) {
            int cmp = key.compareTo(x.key);
            if (cmp == 0) {
                return x.value;
            }
            x = cmp < 0 ? x.left : x.right;
        }
        return null;
    }

    /**
     * Returns the value associated with {@code key}, or {@code null} if the key is absent (or was
     * mapped to {@code null}).
     *
     * @throws NullPointerException if {@code key} is null
     */
    public Value get(Key key) {
        java.util.Objects.requireNonNull(key, "key");
        return get(root, key);
    }

    /** Deep-copies the subtree rooted at {@code x}, including cached heights and sizes. */
    private Node<Key, Value> copy(Node<Key, Value> x) {
        if (x == null) {
            return null;
        }
        Node<Key, Value> y = new Node<>(x.key, x.value);
        y.ht = x.ht;
        y.sz = x.sz;
        y.left = copy(x.left);
        y.right = copy(x.right);
        return y;
    }

    /**
     * Returns an independent structural copy of this map that uses the same balancing strategy.
     * Keys and values themselves are shared, not cloned.
     */
    public AVLOrSBTTreeMap<Key, Value> copy() {
        // Preserve the balance factor; the no-arg constructor would silently switch to HEIGHT.
        AVLOrSBTTreeMap<Key, Value> clone = new AVLOrSBTTreeMap<>(bf);
        clone.root = copy(root);
        return clone;
    }
}