最近公司有这样一个需求:
放置一系列区间(可能有交集,也可能很稀疏),然后判断给定点是否命中某些区间.
举个例子,
设有区间 A[10,100), B[30,80), C[1000,2000),给定点 P=60,判断命中哪些区间?
答案:A区间, B区间.
实际业务中还需要存储一些值,所以需要构造一个有序集合TreeMap,且将重叠的区间拆分开存储.
在此集合中存储形式如下:
A1[10,30), A2B[30,80), A3[80,100), C[1000,2000)
刚开始我直接使用java.util.TreeMap,但它有好多方法是私有的,平白增加时间复杂度.
举个例子,在拆分时,已经明确知道在树中,A2B区间对应节点是A1区间对应节点的后继节点,但插入时只能调用put方法老老实实从root开始遍历.
这也是自己造轮子的原因.
代码如下:
import java.util.*;
import java.util.function.Consumer;
/**
* 红黑树实现的区间有序集合
*/
public class RangeTreeMap<K extends Comparable<K>, V> {
private final Node<K, V> NIL;
private Node<K, V> root;
private int nodeCount = 0;
public RangeTreeMap() {
NIL = new Node<K, V>(RED, null, null, null, null, null);
root = NIL;
}
/**
* 区间拆分式的插入, 例如原有 [10,20,V1), 现插入 [15,25,V2), 得到结果
* [10,15,V1),[15,20,V1V2),[20,25,V2).
*
* @param min 最小值
* @param max 最大值
* @param values 值
*/
public void put(K min, K max, Collection<V> values) {
// 起始: 寻找比 min 小的节点
Node<K, V> point = getLowerNode(root, min);
if (point == NIL) {
// 没有比 min 小的节点,则从头开始
point = getMinNode();
if (point == NIL) {
// tree 为空
normalPut(min, max, values);
return;
}
// pMin >= max : 直接插入,成为最小节点的左孩子
if (point.min.compareTo(max) >= 0) {
insertUpThenDownFunc(point, min, max).add(values);
return;
}
}
// 按照区间大小顺序遍历
while (true) {
// pMax > min ,否则二者没有交集,直接跳转后继节点
if (point.max.compareTo(min) > 0) {
int cmpMin = point.max.compareTo(min);
int cmpMax = point.max.compareTo(max);
if (cmpMin < 0) {
K pMax = point.max;
// 三者共同的第一段处理
point.max = min;
/*
* pMin < min && pMax < max : 处理两段 [pMin,min,pValues),[min,pMax,sum),继续
*/
if (cmpMax < 0) {
// 第二段
insertUpThenDownFunc(point, min, pMax ).add(values).add(point.values);
// 截取,还需继续
min = pMax;
}
/*
* pMin < min && pMax > max : 处理三段
* [pMin,min,pValues),[min,max,sum),[max,pMax,pValues),结束
*/
else if (cmpMax > 0) {
// 第二段
Node<K, V> node = insertUpThenDownFunc(point, min, max).add(values).add(point.values);
// 第三段
insertUpThenDownFunc(node, max, pMax).add(point.values);
return;
}
/*
* pMin < min && pMax == max : 处理两段 [pMin,min,pValues),[min,pMax,sum),结束
*/
else {
// 第二段
insertUpThenDownFunc(point, min, max).add(values).add(point.values);
return;
}
} else if (cmpMin > 0) {
/**
* pMin > min && pMin < max : 处理一段 [min,pMin,values),不跳转继续
*/
if (point.min.compareTo(max) < 0) {
insertUpThenDownFunc(point, min, point.min).add(values);
min = point.min;
continue;
}
/**
* pMin >= max : 处理一段 [min,max,values),结束
*/
else {
insertUpThenDownFunc(point, min, max).add(values);
return;
}
} else {
/**
* pMin == min && pMax < max : 处理一段 [pMin,pMax,sum),继续
*/
if (cmpMax < 0) {
point.add(values);
// 截取,还需继续
min = point.max;
}
/**
* pMin == min && pMax > max : 处理两段 [pMin,max,sum),[max,pMax,pValues),结束
*/
else if (cmpMax > 0) {
Set<V> pValues = new HashSet<>(point.values);
// 第一段
point.add(values);
// 第二段
insertUpThenDownFunc(point, max, point.max).add(pValues);
// 第一段后续
point.max = max;
return;
}
/**
* pMin == min && pMax == max : 处理一段 [pMin,pMax,sum),结束
*/
else {
point.add(values);
return;
}
}
}
Node<K, V> tmp = successor(point);
if (tmp == NIL) { // 退出
break;
} else {
point = tmp; // 跳转后继节点
}
}
// 收尾: 没有比 max 更大的节点,直接插入
insertUpThenDownFunc(point, min, max).add(values);
}
/**
* 普通的插入, 注意: 当直接使用普通插入后,该树将退化为普通的 TreeMap
*
* @param min 最小值
* @param max 最大值
* @param values 值
*/
public void normalPut(K min, K max, Collection<V> values) {
if (min == null || max == null) {
throw new IllegalArgumentException("range tree map dose not allow null min or null max");
}
insertDownFunc(root, min, max).add(values);
}
/**
* 搜索插入,从给定的根节点开始向下搜索插入
*
* @param x 给定根节点
* @param min 区间最小值
* @param max 区间最大值
* @return 插入后的节点
*/
private Node<K, V> insertDownFunc(Node<K, V> x, K min, K max) {
Node<K, V> p = NIL;
int cmp;
while (x != NIL) {
p = x;
cmp = x.min.compareTo(min);
if (cmp > 0) {
x = x.left;
} else if (cmp < 0) {
x = x.right;
} else {
return x;
}
}
Node<K, V> z = new Node<>(RED, min, max, p, NIL, NIL);
if (p == NIL) {
this.root = z;
} else if (z.min.compareTo(p.min) < 0) {
p.left = z;
} else {
p.right = z;
}
insertFix(z);
nodeCount++;
return z;
}
/**
* 搜索插入,相对于上面的普通搜索插入,不认为给定节点是根节点,所以先向上搜索,再向下搜索插入.
* <p>
* 在最好的情况下,只需要搜索一次;在最坏的情况下,退化到 2*logN.
* <p>
* 当给定节点不是待插入节点的前驱/后继节点时,不要使用该方法.
*
* @param x 给定节点
* @param min 区间最小值
* @param max 区间最大值
* @return 插入后的节点
*/
private Node<K, V> insertUpThenDownFunc(Node<K, V> x, K min, K max) {
Node<K, V> p = NIL;
int cmp;
while (x != NIL) {
if (x.parent == NIL) {
break;
}
cmp = x.parent.min.compareTo(min);
if (cmp > 0) {
if (x.parent.right == x) {
x = x.parent;
} else {
break;
}
} else if (cmp < 0) {
if (x.parent.left == x) {
x = x.parent;
} else {
break;
}
} else {
return x;
}
}
return insertDownFunc(x, min, max);
}
/**
* 插入修正
*
* @param z 待修正的节点
*/
private void insertFix(Node<K, V> z) {
Node<K, V> p;
while (z.parent.color == RED) {
if (z.parent == z.parent.parent.left) {
p = z.parent.parent.right;
if (p.color == RED) {
z.parent.color = BLACK;
p.color = BLACK;
z.parent.parent.color = RED;
z = z.parent.parent;
} else {
if (z == z.parent.right) {
z = z.parent;
leftRotate(z);
}
z.parent.color = BLACK;
z.parent.parent.color = RED;
rightRotate(z.parent.parent);
}
} else {
p = z.parent.parent.left;
if (p.color == RED) {
z.parent.color = BLACK;
p.color = BLACK;
z.parent.parent.color = RED;
z = z.parent.parent;
} else {
if (z == z.parent.left) {
z = z.parent;
rightRotate(z);
}
z.parent.color = BLACK;
z.parent.parent.color = RED;
leftRotate(z.parent.parent);
}
}
}
root.color = BLACK;
}
/**
* 左旋
*
* @param x 待左旋的节点
*/
private void leftRotate(Node<K, V> x) {
Node<K, V> p = x.right;
x.right = p.left;
if (p.left != NIL) {
p.left.parent = x;
}
p.parent = x.parent;
if (x.parent == NIL) {
root = p;
} else if (x == x.parent.left) {
x.parent.left = p;
} else {
x.parent.right = p;
}
p.left = x;
x.parent = p;
}
/**
* 右旋
*
* @param x 待右旋的节点
*/
private void rightRotate(Node<K, V> x) {
Node<K, V> p = x.left;
x.left = p.right;
if (p.right != NIL) {
p.right.parent = x;
}
p.parent = x.parent;
if (x.parent == NIL) {
root = p;
} else if (x == x.parent.left) {
x.parent.left = p;
} else {
x.parent.right = p;
}
p.right = x;
x.parent = p;
}
public void delete(K min) {
Node<K, V> node;
if (key == null || (node = search(min) == NIL)) {
return;
}
deleteNode(node);
}
private void deleteNode(Node<K, V> node) {
if (node == NIL) {
return;
}
Node<K, V> x;
Node<K, V> y = node;
boolean color = y.color;
if (node.left == NIL) {
x = node.right;
transplant(node, node.right);
} else if (node.right == NIL) {
x = node.left;
transplant(node, node.left);
} else {
y = successor(node.right);
color = y.color;
x = y.right;
if (y.parent == node) {
x.parent = y;
} else {
transplant(y, y.right);
y.right = node.right;
y.right.parent = y;
}
transplant(node, y);
y.left = node.left;
y.left.parent = y;
y.color = node.color;
}
if (color == BLACK) {
deleteFix(x);
}
nodeCount--;
}
private void transplant(Node<K, V> u, Node<K, V> v) {
if (u.parent == NIL) {
root = v;
} else if (u == u.parent.left) {
u.parent.left = v;
} else
u.parent.right = v;
v.parent = u.parent;
}
/**
* 删除修正
*
* @param x 待修正的节点
*/
private void deleteFix(Node<K, V> x) {
while (x != root && x.color == BLACK) {
if (x == x.parent.left) {
Node<K, V> w = x.parent.right;
if (w.color == RED) {
w.color = BLACK;
x.parent.color = RED;
leftRotate(x.parent);
w = x.parent.right;
}
if (w.left.color == BLACK && w.right.color == BLACK) {
w.color = RED;
x = x.parent;
continue;
} else if (w.right.color == BLACK) {
w.left.color = BLACK;
w.color = RED;
rightRotate(w);
w = x.parent.right;
}
if (w.right.color == RED) {
w.color = x.parent.color;
x.parent.color = BLACK;
w.right.color = BLACK;
leftRotate(x.parent);
x = root;
}
} else {
Node<K, V> w = (x.parent.left);
if (w.color == RED) {
w.color = BLACK;
x.parent.color = RED;
rightRotate(x.parent);
w = x.parent.left;
}
if (w.right.color == BLACK && w.left.color == BLACK) {
w.color = RED;
x = x.parent;
continue;
} else if (w.left.color == BLACK) {
w.right.color = BLACK;
w.color = RED;
leftRotate(w);
w = x.parent.left;
}
if (w.left.color == RED) {
w.color = x.parent.color;
x.parent.color = BLACK;
w.left.color = BLACK;
rightRotate(x.parent);
x = root;
}
}
}
x.color = BLACK;
}
/**
* 获取指定节点的后继节点
*
* @param node 指定节点
* @return 后继节点
*/
private Node<K, V> successor(Node<K, V> node) {
if (node == NIL) {
return NIL;
} else if (node.right != NIL) {
Node<K, V> p = node.right;
while (p.left != NIL) {
p = p.left;
}
return p;
} else {
Node<K, V> p = node.parent;
Node<K, V> ch = node;
while (p != NIL && ch == p.right) {
ch = p;
p = p.parent;
}
return p;
}
}
/**
* 获取指定节点的前驱节点
*
* @param node 指定节点
* @return 前驱节点
*/
private Node<K, V> predecessor(Node<K, V> node) {
if (node == NIL) {
return NIL;
} else if (node.left != NIL) {
Node<K, V> p = node.left;
while (p.right != NIL) {
p = p.right;
}
return p;
} else {
Node<K, V> p = node.parent;
Node<K, V> ch = node;
while (p != NIL && ch == p.left) {
ch = p;
p = p.parent;
}
return p;
}
}
public boolean contains(K min) {
return search(min) != NIL;
}
/**
* 根据指定 key 查找节点
*
* @param min 指定的 key
* @return 对应的节点
*/
private Node<K, V> search(K min) {
if (min == null) {
return NIL;
}
Node<K, V> node = root;
while (node != NIL) {
int cmp = min.compareTo(node.min);
if (cmp > 0) {
node = node.right;
} else if (cmp < 0) {
node = node.left;
} else {
return node;
}
}
return NIL;
}
private Node<K, V> getMinNode() {
Node<K, V> node = root;
while (node.left != NIL) {
node = node.left;
}
return node;
}
private Node<K, V> getMaxNode() {
Node<K, V> node = root;
while (node.right != NIL) {
node = node.right;
}
return node;
}
private Node<K, V> getLowerNode(Node<K, V> node, K min) {
if (min == null) {
return NIL;
}
while (node != NIL) {
if (min.compareTo(node.min) > 0) {
if (node.right != NIL) {
node = node.right;
} else {
return node;
}
} else {
if (node.left != NIL) {
node = node.left;
} else {
Node<K, V> parent = node.parent;
Node<K, V> ch = node;
while (parent != NIL && ch == parent.left) {
ch = parent;
parent = parent.parent;
}
return parent;
}
}
}
return NIL;
}
private void foreach(Consumer<Node<K, V>> consumer) {
Stack<Node<K, V>> stack = new Stack<>();
stack.push(root);
Node<K, V> trav = root;
while (root != NIL && !stack.isEmpty()) {
while (trav != NIL && trav.left != NIL) {
stack.push(trav.left);
trav = trav.left;
}
Node<K, V> node = stack.pop();
consumer.accept(node);
if (node.right != NIL) {
stack.push(node.right);
trav = node.right;
}
}
}
private static final boolean RED = true;
private static final boolean BLACK = false;
/**
* 红黑树子节点
*/
public static class Node<K extends Comparable<K>, V> {
/**
* 节点颜色
*/
private boolean color;
/**
* 区间最小值,且在树中排序也按照最小值排序
*/
private K min;
/**
* 区间最大值
*/
private K max;
/**
* 该区间对应的存储数据
*/
private Set<V> values;
/**
* 左孩子,右孩子,父节点
*/
private Node<K, V> left, right, parent;
Node(boolean color, K min, K max, Node<K, V> parent, Node<K, V> left, Node<K, V> right) {
if (parent == null && left == null && right == null) {
parent = this;
left = this;
right = this;
}
}
private Node<K, V> add(Collection<V> values) {
if (values != null) {
if (this.values == null) {
this.values = new HashSet<>();
}
}
return this;
}
@Override
public String toString() {
return min + "-" + max + ":" + values;
}
}
}