又写了一遍红黑树。写过两遍之后，感觉其实也不是很复杂：只要把旋转的代码写完并且写正确，手写红黑树就可以信手拈来了。
一、设计思路
1、设计结构图
2、接口代码
/**
 * Basic operations of a set abstract data type.
 *
 * @param <E> element type
 */
public interface Set<E> {

    /** @return number of elements currently held */
    int size();

    /** @return {@code true} when no elements are held */
    boolean isEmpty();

    /** @return {@code true} if {@code element} is present */
    boolean contains(E element);

    /** Inserts {@code element}. */
    void add(E element);

    /** Removes {@code element}. */
    void remove(E element);

    /** Removes every element. */
    void clear();

    /**
     * Hands the elements to a caller-supplied callback.
     *
     * @param visitor callback executed by the implementation for the elements
     */
    void traversal(Visitor<E> visitor);

    /**
     * Caller-supplied callback used by {@link #traversal(Visitor)}.
     * Being nested in an interface, this class is implicitly {@code public static}.
     *
     * @param <E> element type
     */
    abstract class Visitor<E> {
        /** Invoked for a visited element. */
        public abstract void visit(E element);
    }
}
/**
 * Basic operations of a key/value map abstract data type.
 *
 * @param <K> key type
 * @param <V> value type
 */
public interface Map<K, V> {

    /** @return number of entries currently held */
    int size();

    /** @return {@code true} when no entries are held */
    boolean isEmpty();

    /** @return the value mapped to {@code key} */
    V get(K key);

    /** @return {@code true} if an entry for {@code key} exists */
    boolean containsKey(K key);

    /** @return {@code true} if some entry holds {@code value} */
    boolean containsValue(V value);

    /**
     * Associates {@code value} with {@code key}.
     *
     * @return the value previously associated with {@code key}
     */
    V put(K key, V value);

    /**
     * Removes the entry for {@code key}.
     *
     * @return the value that was associated with {@code key}
     */
    V remove(K key);

    /** Removes every entry. */
    void clear();

    /**
     * Hands the entries to a caller-supplied callback.
     *
     * @param visitor callback executed by the implementation for the entries
     */
    void traversal(Visitor<K, V> visitor);

    /**
     * Caller-supplied callback used by {@link #traversal(Visitor)}.
     * Being nested in an interface, this class is implicitly {@code public static}.
     *
     * @param <K> key type
     * @param <V> value type
     */
    abstract class Visitor<K, V> {
        /** Invoked for a visited entry. */
        public abstract void visit(K key, V value);
    }
}
二、旋转图解
1、左旋转
2、右旋转
三、代码
TreeMap 代码：
/**
 * A sorted {@link Map} implemented as a red-black tree.
 *
 * <p>Keys are ordered by the supplied {@link Comparator}, or, when none was supplied, by casting
 * them to {@link Comparable}. {@code put} and {@code remove} run red-black fix-ups
 * ({@code afterPut} / {@code afterRemove}) after modifying the tree.
 *
 * @param <K> key type
 * @param <V> value type
 */
public class TreeMap<K, V> implements Map<K, V> {
    private static final int RED = 1;
    private static final int BLACK = 0;

    /** Key ordering; {@code null} means the keys are cast to {@link Comparable}. */
    private final Comparator<K> comparator;
    private Node<K, V> root;
    private int size = 0;

    /** Creates an empty map that orders keys via {@link Comparable}. */
    public TreeMap() {
        this(null);
    }

    /**
     * Creates an empty map ordered by {@code comparator}.
     *
     * @param comparator key ordering, or {@code null} to order keys via {@link Comparable}
     */
    public TreeMap(Comparator<K> comparator) {
        this.comparator = comparator;
    }

    /** A tree node; every new node starts out RED, as red-black insertion requires. */
    private static class Node<K, V> {
        Node<K, V> left;
        Node<K, V> right;
        Node<K, V> parent;
        K key;
        V value;
        int color = RED;

        public Node(Node<K, V> parent, K key, V value) {
            this.parent = parent;
            this.key = key;
            this.value = value;
        }

        @Override
        public String toString() {
            return (color == RED ? "Red_" : "Black_") + key + "--" + value;
        }
    }

    private void setColor(Node<K, V> node, int color) {
        node.color = color;
    }

    private void toRed(Node<K, V> node) {
        node.color = RED;
    }

    private void toBlack(Node<K, V> node) {
        node.color = BLACK;
    }

    /** Null-safe colour lookup: absent (null) children count as BLACK. */
    private int colorOf(Node<K, V> node) {
        return node == null ? BLACK : node.color;
    }

    private boolean isRed(Node<K, V> node) {
        return node != null && node.color == RED;
    }

    private boolean isBlack(Node<K, V> node) {
        return node == null || node.color == BLACK;
    }

    /**
     * Compares two keys with the comparator if present, otherwise via {@link Comparable}.
     * Throws {@link ClassCastException} if no comparator was given and keys are not Comparable.
     */
    @SuppressWarnings("unchecked")
    private int compare(K k1, K k2) {
        if (comparator != null) {
            return comparator.compare(k1, k2);
        }
        return ((Comparable<K>) k1).compareTo(k2);
    }

    @Override
    public int size() {
        return this.size;
    }

    @Override
    public boolean isEmpty() {
        return this.size == 0;
    }

    @Override
    public void clear() {
        this.size = 0;
        root = null;
    }

    /**
     * Inserts or overwrites the mapping for {@code key}.
     *
     * @return the previous value for {@code key}, or {@code null} if the key was absent
     */
    @Override
    public V put(K key, V value) {
        if (root == null) {
            root = new Node<>(null, key, value);
            afterPut(root);
            size++;
            return null;
        }
        Node<K, V> node = root;
        Node<K, V> parent = root;
        int cmp = 0;
        while (node != null) {
            cmp = compare(key, node.key);
            if (cmp == 0) {
                // Key already present: overwrite in place; the tree shape does not change.
                V old = node.value;
                node.key = key;
                node.value = value;
                return old;
            }
            parent = node;
            node = cmp > 0 ? node.right : node.left;
        }
        Node<K, V> newNode = new Node<>(parent, key, value);
        if (cmp > 0) {
            parent.right = newNode;
        } else {
            parent.left = newNode;
        }
        size++; // previously missing: only root inserts were counted
        afterPut(newNode);
        return null;
    }

    /** Restores the red-black invariants after {@code node} was linked in as a RED leaf. */
    private void afterPut(Node<K, V> node) {
        Node<K, V> parent = node.parent;
        if (node == root || parent == null) { // the root is always black
            toBlack(node);
            return;
        }
        if (isBlack(parent)) { // a red child under a black parent needs no fix-up
            return;
        }
        // parent is red, so it is not the root and grand is non-null
        Node<K, V> grand = parent.parent;
        Node<K, V> uncle = grand.left == parent ? grand.right : grand.left;
        if (isRed(uncle)) { // overflow: push the red up and re-check at grand
            toBlack(parent);
            toBlack(uncle);
            toRed(grand);
            afterPut(grand);
            return;
        }
        toRed(grand);
        if (grand.left == parent) {
            if (parent.left == node) { // LL
                toBlack(parent);
            } else { // LR
                toBlack(node);
                rotateLeft(parent);
            }
            rotateRight(grand);
        } else {
            if (parent.left == node) { // RL
                toBlack(node);
                rotateRight(parent);
            } else { // RR
                toBlack(parent);
            }
            rotateLeft(grand);
        }
    }

    /** Left-rotates around {@code grand}; its right child takes its place. */
    private void rotateLeft(Node<K, V> grand) {
        Node<K, V> parent = grand.right;
        grand.right = parent.left;
        if (grand.right != null) { // parent's left subtree may be empty
            grand.right.parent = grand;
        }
        parent.left = grand;
        afterRotate(grand, parent);
    }

    /** Right-rotates around {@code grand}; its left child takes its place. */
    private void rotateRight(Node<K, V> grand) {
        Node<K, V> parent = grand.left;
        grand.left = parent.right;
        if (grand.left != null) { // parent's right subtree may be empty
            grand.left.parent = grand;
        }
        parent.right = grand;
        afterRotate(grand, parent);
    }

    /** Hooks {@code parent} into {@code grand}'s old position and makes it grand's parent. */
    private void afterRotate(Node<K, V> grand, Node<K, V> parent) {
        parent.parent = grand.parent;
        if (root == grand) {
            root = parent;
        } else if (grand.parent.left == grand) {
            grand.parent.left = parent;
        } else {
            grand.parent.right = parent;
        }
        grand.parent = parent;
    }

    private Node<K, V> getNode(K key) {
        Node<K, V> node = root;
        while (node != null) {
            int cmp = compare(key, node.key);
            if (cmp > 0) {
                node = node.right;
            } else if (cmp < 0) {
                node = node.left;
            } else {
                return node;
            }
        }
        return null;
    }

    /**
     * Finds the in-order predecessor: the rightmost node of the left subtree, or else the
     * first ancestor reached by moving up from a right child.
     *
     * @return the predecessor, or {@code null} if {@code node} is null or the minimum
     */
    private Node<K, V> getPrecursor(Node<K, V> node) {
        if (node == null) {
            return null;
        }
        Node<K, V> p = node.left;
        if (p != null) {
            while (p.right != null) {
                p = p.right;
            }
            return p;
        }
        // Stop at the root (parent == null) instead of dereferencing a null parent.
        while (node.parent != null && node == node.parent.left) {
            node = node.parent;
        }
        return node.parent;
    }

    /**
     * Finds the in-order successor: the leftmost node of the right subtree, or else the
     * first ancestor reached by moving up from a left child.
     *
     * @return the successor, or {@code null} if {@code node} is null or the maximum
     */
    private Node<K, V> getSuccessor(Node<K, V> node) {
        if (node == null) {
            return null;
        }
        Node<K, V> p = node.right;
        if (p != null) {
            while (p.left != null) {
                p = p.left;
            }
            return p;
        }
        // Stop at the root (parent == null) instead of dereferencing a null parent.
        while (node.parent != null && node == node.parent.right) {
            node = node.parent;
        }
        return node.parent;
    }

    @Override
    public V get(K key) {
        Node<K, V> node = getNode(key);
        return node == null ? null : node.value;
    }

    /**
     * Removes the mapping for {@code key}.
     *
     * @return the removed value, or {@code null} if the key was absent
     */
    @Override
    public V remove(K key) {
        Node<K, V> node = getNode(key);
        if (node == null) {
            return null;
        }
        V old = node.value;
        size--;
        if (node.left != null && node.right != null) {
            // Degree 2: copy the predecessor's entry here, then delete the predecessor,
            // which has at most one (left) child.
            Node<K, V> pre = getPrecursor(node);
            node.key = pre.key;
            node.value = pre.value;
            node = pre;
        }
        Node<K, V> replacement = node.left != null ? node.left : node.right;
        if (replacement == null) { // degree 0: unlink the leaf
            if (node.parent == null) {
                root = null;
            } else if (node.parent.left == node) {
                node.parent.left = null;
            } else {
                node.parent.right = null;
            }
            // node keeps its parent pointer so afterRemove can locate its old side
            afterRemove(node);
        } else { // degree 1: splice the only child into node's place
            if (node.parent == null) {
                root = replacement;
            } else if (node.parent.left == node) {
                node.parent.left = replacement;
            } else {
                node.parent.right = replacement;
            }
            replacement.parent = node.parent;
            afterRemove(replacement);
        }
        return old;
    }

    /**
     * Restores the red-black invariants after a removal. {@code node} is either the detached
     * leaf that was removed or the child that replaced the removed node.
     */
    private void afterRemove(Node<K, V> node) {
        if (node == null) {
            return;
        }
        if (isRed(node)) { // a red removed node or red replacement: just blacken it
            toBlack(node);
            return;
        }
        Node<K, V> parent = node.parent;
        if (parent == null) {
            return;
        }
        // A detached leaf leaves a null slot on its side, so a null left slot means "left".
        boolean removedOnLeft = parent.left == null || parent.left == node;
        Node<K, V> sibling = removedOnLeft ? parent.right : parent.left;
        if (removedOnLeft) {
            if (isRed(sibling)) { // rotate so a black nephew becomes the sibling
                toBlack(sibling);
                toRed(parent);
                rotateLeft(parent);
                sibling = parent.right;
            }
            if (isBlack(sibling.left) && isBlack(sibling.right)) { // underflow: merge
                boolean parentWasBlack = isBlack(parent);
                toRed(sibling);
                toBlack(parent);
                if (parentWasBlack) {
                    afterRemove(parent);
                }
            } else if (isRed(sibling.right)) { // RR
                setColor(sibling, colorOf(parent));
                toBlack(parent);
                toBlack(sibling.right);
                rotateLeft(parent);
            } else { // RL
                setColor(sibling.left, colorOf(parent));
                toBlack(parent);
                toBlack(sibling);
                rotateRight(sibling);
                rotateLeft(parent);
            }
        } else {
            if (isRed(sibling)) { // rotate so a black nephew becomes the sibling
                toBlack(sibling);
                toRed(parent);
                rotateRight(parent);
                sibling = parent.left;
            }
            if (isBlack(sibling.left) && isBlack(sibling.right)) { // underflow: merge
                boolean parentWasBlack = isBlack(parent);
                toRed(sibling);
                toBlack(parent);
                if (parentWasBlack) {
                    afterRemove(parent);
                }
            } else if (isRed(sibling.left)) { // LL
                setColor(sibling, colorOf(parent));
                toBlack(parent);
                toBlack(sibling.left);
                rotateRight(parent);
            } else { // LR
                setColor(sibling.right, colorOf(parent));
                toBlack(parent);
                toBlack(sibling);
                rotateLeft(sibling);
                rotateRight(parent);
            }
        }
    }

    @Override
    public boolean containsKey(K key) {
        return getNode(key) != null;
    }

    /** Null-safe value equality. */
    private boolean compareValue(V v1, V v2) {
        return v1 == null ? v2 == null : v1.equals(v2);
    }

    /**
     * Level-order scan of the tree for {@code value}.
     *
     * @return {@code true} if some entry holds a value equal to {@code value}
     */
    @Override
    public boolean containsValue(V value) {
        if (root == null) { // previously NPE'd by enqueueing a null root
            return false;
        }
        Queue<Node<K, V>> queue = new LinkedList<>();
        queue.offer(root);
        while (!queue.isEmpty()) {
            Node<K, V> node = queue.poll();
            if (compareValue(node.value, value)) {
                return true;
            }
            if (node.left != null) {
                queue.offer(node.left);
            }
            if (node.right != null) {
                queue.offer(node.right);
            }
        }
        return false;
    }

    /** In-order (ascending key) traversal of the subtree rooted at {@code node}. */
    private void traversal(Node<K, V> node, Visitor<K, V> visitor) {
        if (node == null) {
            return;
        }
        traversal(node.left, visitor);
        visitor.visit(node.key, node.value);
        traversal(node.right, visitor);
    }

    /** Visits every entry in ascending key order. */
    @Override
    public void traversal(Visitor<K, V> visitor) {
        traversal(root, visitor);
    }
}
TreeSet 代码：
/**
 * Sorted {@link Set} backed by the red-black {@code TreeMap}: each element is stored as a map
 * key with a {@code null} value.
 *
 * @param <E> element type
 */
public class TreeSet<E> implements Set<E> {
    private final TreeMap<E, Object> tree;

    /** Creates an empty set that orders elements via {@link Comparable}. */
    public TreeSet() {
        this(null);
    }

    /**
     * Creates an empty set ordered by {@code comparator}.
     *
     * @param comparator element ordering, or {@code null} to order elements via Comparable
     */
    public TreeSet(java.util.Comparator<E> comparator) {
        tree = new TreeMap<>(comparator);
    }

    @Override
    public int size() {
        return tree.size();
    }

    @Override
    public boolean isEmpty() {
        return tree.isEmpty();
    }

    @Override
    public void clear() {
        tree.clear();
    }

    @Override
    public boolean contains(E element) {
        return tree.containsKey(element);
    }

    @Override
    public void add(E element) {
        tree.put(element, null);
    }

    @Override
    public void remove(E element) {
        tree.remove(element);
    }

    /**
     * Passes every element to {@code visitor} in the backing map's key order.
     * Previously the visitor was ignored and each key was printed to stdout instead.
     *
     * @param visitor callback to receive each element; a {@code null} visitor is a no-op
     */
    @Override
    public void traversal(final Visitor<E> visitor) {
        if (visitor == null) {
            return;
        }
        tree.traversal(new Map.Visitor<E, Object>() {
            @Override
            public void visit(E key, Object value) {
                visitor.visit(key);
            }
        });
    }
}