利用红黑树实现的线段/区间有序集合

最近公司有这样一个需求:

放置一系列区间(可能有交集,也可能很稀疏),然后判断给定点是否命中某些区间.


举个例子,

设有区间 A[10,100), B[30,80), C[1000,2000),给定点 P=60,判断命中哪些区间?

答案:A区间, B区间.


实际业务中还需要存储一些值,所以需要构造一个有序集合TreeMap,且将重叠的区间拆分开存储.

在此集合中存储形式如下:

A1[10,30), A2B[30,80), A3[80,100), C[1000,2000)

刚开始我直接使用java.util.TreeMap,但它有好多方法是私有的,平白增加时间复杂度.

举个例子,在拆分时,已经明确知道在树中,A2B区间对应节点是A1区间对应节点的后继节点,但插入时只能调用put方法老老实实从root开始遍历.

这也是自己造轮子的原因.


代码如下:

import java.util.*;
import java.util.function.Consumer;

/**
 * 红黑树实现的区间有序集合
 */
public class RangeTreeMap<K extends Comparable<K>, V> {

    private final Node<K, V> NIL;

    private Node<K, V> root;
    private int nodeCount = 0;

    public RangeTreeMap() {
        NIL = new Node<K, V>(RED, null, null, null, null, null);
        root = NIL;
    }

    /**
     * 区间拆分式的插入, 例如原有 [10,20,V1), 现插入 [15,25,V2), 得到结果
     * [10,15,V1),[15,20,V1V2),[20,25,V2).
     * 
     * @param min    最小值
     * @param max    最大值
     * @param values 值
     */
    public void put(K min, K max, Collection<V> values) {
        // 起始: 寻找比 min 小的节点
        Node<K, V> point = getLowerNode(root, min);
        if (point == NIL) {
            // 没有比 min 小的节点,则从头开始
            point = getMinNode();
            if (point == NIL) {
                // tree 为空
                normalPut(min, max, values);
                return;
            }
            // pMin >= max : 直接插入,成为最小节点的左孩子
            if (point.min.compareTo(max) >= 0) {
                insertUpThenDownFunc(point, min, max).add(values);
                return;
            }
        }
        // 按照区间大小顺序遍历
        while (true) {
            // pMax > min ,否则二者没有交集,直接跳转后继节点
            if (point.max.compareTo(min) > 0) {
                int cmpMin = point.max.compareTo(min);
                int cmpMax = point.max.compareTo(max);
                if (cmpMin < 0) {
                    K pMax = point.max;
                    // 三者共同的第一段处理
                    point.max = min;
                    /*
                     * pMin < min && pMax < max : 处理两段 [pMin,min,pValues),[min,pMax,sum),继续
                     */
                    if (cmpMax < 0) {
                        // 第二段
                        insertUpThenDownFunc(point, min, pMax ).add(values).add(point.values);
                        // 截取,还需继续
                        min = pMax;
                    }
                    /*
                     * pMin < min && pMax > max : 处理三段
                     * [pMin,min,pValues),[min,max,sum),[max,pMax,pValues),结束
                     */
                    else if (cmpMax > 0) {
                        // 第二段
                        Node<K, V> node = insertUpThenDownFunc(point, min, max).add(values).add(point.values);
                        // 第三段
                        insertUpThenDownFunc(node, max, pMax).add(point.values);
                        return;
                    }
                    /*
                     * pMin < min && pMax == max : 处理两段 [pMin,min,pValues),[min,pMax,sum),结束
                     */
                    else {
                        // 第二段
                        insertUpThenDownFunc(point, min, max).add(values).add(point.values);
                        return;
                    }
                } else if (cmpMin > 0) {
                    /**
                     * pMin > min && pMin < max : 处理一段 [min,pMin,values),不跳转继续
                     */
                    if (point.min.compareTo(max) < 0) {
                        insertUpThenDownFunc(point, min, point.min).add(values);
                        min = point.min;
                        continue;
                    }
                    /**
                     * pMin >= max : 处理一段 [min,max,values),结束
                     */
                    else {
                        insertUpThenDownFunc(point, min, max).add(values);
                        return;
                    }
                } else {
                    /**
                     * pMin == min && pMax < max : 处理一段 [pMin,pMax,sum),继续
                     */
                    if (cmpMax < 0) {
                        point.add(values);
                        // 截取,还需继续
                        min = point.max;
                    }
                    /**
                     * pMin == min && pMax > max : 处理两段 [pMin,max,sum),[max,pMax,pValues),结束
                     */
                    else if (cmpMax > 0) {
                        Set<V> pValues = new HashSet<>(point.values);
                        // 第一段
                        point.add(values);
                        // 第二段
                        insertUpThenDownFunc(point, max, point.max).add(pValues);
                        // 第一段后续
                        point.max = max;
                        return;
                    }
                    /**
                     * pMin == min && pMax == max : 处理一段 [pMin,pMax,sum),结束
                     */
                    else {
                        point.add(values);
                        return;
                    }
                }
            }
            Node<K, V> tmp = successor(point);
            if (tmp == NIL) { // 退出
                break;
            } else {
                point = tmp; // 跳转后继节点
            }
        }
        // 收尾: 没有比 max 更大的节点,直接插入
        insertUpThenDownFunc(point, min, max).add(values);
    }

    /**
     * 普通的插入, 注意: 当直接使用普通插入后,该树将退化为普通的 TreeMap
     * 
     * @param min    最小值
     * @param max    最大值
     * @param values 值
     */
    public void normalPut(K min, K max, Collection<V> values) {
        if (min == null || max == null) {
            throw new IllegalArgumentException("range tree map dose not allow null min or null max");
        }
        insertDownFunc(root, min, max).add(values);
    }

    /**
     * 搜索插入,从给定的根节点开始向下搜索插入
     * 
     * @param x   给定根节点
     * @param min 区间最小值
     * @param max 区间最大值
     * @return 插入后的节点
     */
    private Node<K, V> insertDownFunc(Node<K, V> x, K min, K max) {
        Node<K, V> p = NIL;
        int cmp;
        while (x != NIL) {
            p = x;
            cmp = x.min.compareTo(min);
            if (cmp > 0) {
                x = x.left;
            } else if (cmp < 0) {
                x = x.right;
            } else {
                return x;
            }
        }
        Node<K, V> z = new Node<>(RED, min, max, p, NIL, NIL);
        if (p == NIL) {
            this.root = z;
        } else if (z.min.compareTo(p.min) < 0) {
            p.left = z;
        } else {
            p.right = z;
        }
        insertFix(z);
        nodeCount++;
        return z;
    }

    /**
     * 搜索插入,相对于上面的普通搜索插入,不认为给定节点是根节点,所以先向上搜索,再向下搜索插入.
     * <p>
     * 在最好的情况下,只需要搜索一次;在最坏的情况下,退化到 2*logN.
     * <p>
     * 当给定节点不是待插入节点的前驱/后继节点时,不要使用该方法.
     * 
     * @param x   给定节点
     * @param min 区间最小值
     * @param max 区间最大值
     * @return 插入后的节点
     */
    private Node<K, V> insertUpThenDownFunc(Node<K, V> x, K min, K max) {
        Node<K, V> p = NIL;
        int cmp;
        while (x != NIL) {
            if (x.parent == NIL) {
                break;
            }
            cmp = x.parent.min.compareTo(min);
            if (cmp > 0) {
                if (x.parent.right == x) {
                    x = x.parent;
                } else {
                    break;
                }
            } else if (cmp < 0) {
                if (x.parent.left == x) {
                    x = x.parent;
                } else {
                    break;
                }
            } else {
                return x;
            }
        }
        return insertDownFunc(x, min, max);
    }

    /**
     * 插入修正
     * 
     * @param z 待修正的节点
     */
    private void insertFix(Node<K, V> z) {
        Node<K, V> p;
        while (z.parent.color == RED) {
            if (z.parent == z.parent.parent.left) {
                p = z.parent.parent.right;
                if (p.color == RED) {
                    z.parent.color = BLACK;
                    p.color = BLACK;
                    z.parent.parent.color = RED;
                    z = z.parent.parent;
                } else {
                    if (z == z.parent.right) {
                        z = z.parent;
                        leftRotate(z);
                    }
                    z.parent.color = BLACK;
                    z.parent.parent.color = RED;
                    rightRotate(z.parent.parent);
                }
            } else {
                p = z.parent.parent.left;
                if (p.color == RED) {
                    z.parent.color = BLACK;
                    p.color = BLACK;
                    z.parent.parent.color = RED;
                    z = z.parent.parent;
                } else {
                    if (z == z.parent.left) {
                        z = z.parent;
                        rightRotate(z);
                    }
                    z.parent.color = BLACK;
                    z.parent.parent.color = RED;
                    leftRotate(z.parent.parent);
                }
            }
        }
        root.color = BLACK;
    }

    /**
     * 左旋
     * 
     * @param x 待左旋的节点
     */
    private void leftRotate(Node<K, V> x) {
        Node<K, V> p = x.right;
        x.right = p.left;
        if (p.left != NIL) {
            p.left.parent = x;
        }
        p.parent = x.parent;
        if (x.parent == NIL) {
            root = p;
        } else if (x == x.parent.left) {
            x.parent.left = p;
        } else {
            x.parent.right = p;
        }
        p.left = x;
        x.parent = p;
    }

    /**
     * 右旋
     * 
     * @param x 待右旋的节点
     */
    private void rightRotate(Node<K, V> x) {
        Node<K, V> p = x.left;
        x.left = p.right;
        if (p.right != NIL) {
            p.right.parent = x;
        }
        p.parent = x.parent;
        if (x.parent == NIL) {
            root = p;
        } else if (x == x.parent.left) {
            x.parent.left = p;
        } else {
            x.parent.right = p;
        }
        p.right = x;
        x.parent = p;
    }

    public void delete(K min) {
        Node<K, V> node;
        if (key == null || (node = search(min) == NIL)) {
            return;
        }
        deleteNode(node);
    }

    private void deleteNode(Node<K, V> node) {
        if (node == NIL) {
            return;
        }
        Node<K, V> x;
        Node<K, V> y = node;
        boolean color = y.color;
        if (node.left == NIL) {
            x = node.right;
            transplant(node, node.right);
        } else if (node.right == NIL) {
            x = node.left;
            transplant(node, node.left);
        } else {
            y = successor(node.right);
            color = y.color;
            x = y.right;
            if (y.parent == node) {
                x.parent = y;
            } else {
                transplant(y, y.right);
                y.right = node.right;
                y.right.parent = y;
            }
            transplant(node, y);
            y.left = node.left;
            y.left.parent = y;
            y.color = node.color;
        }
        if (color == BLACK) {
            deleteFix(x);
        }
        nodeCount--;
    }

    private void transplant(Node<K, V> u, Node<K, V> v) {
        if (u.parent == NIL) {
            root = v;
        } else if (u == u.parent.left) {
            u.parent.left = v;
        } else
            u.parent.right = v;
        v.parent = u.parent;
    }

    /**
     * 删除修正
     * 
     * @param x 待修正的节点
     */
    private void deleteFix(Node<K, V> x) {
        while (x != root && x.color == BLACK) {
            if (x == x.parent.left) {
                Node<K, V> w = x.parent.right;
                if (w.color == RED) {
                    w.color = BLACK;
                    x.parent.color = RED;
                    leftRotate(x.parent);
                    w = x.parent.right;
                }
                if (w.left.color == BLACK && w.right.color == BLACK) {
                    w.color = RED;
                    x = x.parent;
                    continue;
                } else if (w.right.color == BLACK) {
                    w.left.color = BLACK;
                    w.color = RED;
                    rightRotate(w);
                    w = x.parent.right;
                }
                if (w.right.color == RED) {
                    w.color = x.parent.color;
                    x.parent.color = BLACK;
                    w.right.color = BLACK;
                    leftRotate(x.parent);
                    x = root;
                }
            } else {
                Node<K, V> w = (x.parent.left);
                if (w.color == RED) {
                    w.color = BLACK;
                    x.parent.color = RED;
                    rightRotate(x.parent);
                    w = x.parent.left;
                }
                if (w.right.color == BLACK && w.left.color == BLACK) {
                    w.color = RED;
                    x = x.parent;
                    continue;
                } else if (w.left.color == BLACK) {
                    w.right.color = BLACK;
                    w.color = RED;
                    leftRotate(w);
                    w = x.parent.left;
                }
                if (w.left.color == RED) {
                    w.color = x.parent.color;
                    x.parent.color = BLACK;
                    w.left.color = BLACK;
                    rightRotate(x.parent);
                    x = root;
                }
            }
        }
        x.color = BLACK;
    }

    /**
     * 获取指定节点的后继节点
     * 
     * @param node 指定节点
     * @return 后继节点
     */
    private Node<K, V> successor(Node<K, V> node) {
        if (node == NIL) {
            return NIL;
        } else if (node.right != NIL) {
            Node<K, V> p = node.right;
            while (p.left != NIL) {
                p = p.left;
            }
            return p;
        } else {
            Node<K, V> p = node.parent;
            Node<K, V> ch = node;
            while (p != NIL && ch == p.right) {
                ch = p;
                p = p.parent;
            }
            return p;
        }
    }

    /**
     * 获取指定节点的前驱节点
     * 
     * @param node 指定节点
     * @return 前驱节点
     */
    private Node<K, V> predecessor(Node<K, V> node) {
        if (node == NIL) {
            return NIL;
        } else if (node.left != NIL) {
            Node<K, V> p = node.left;
            while (p.right != NIL) {
                p = p.right;
            }
            return p;
        } else {
            Node<K, V> p = node.parent;
            Node<K, V> ch = node;
            while (p != NIL && ch == p.left) {
                ch = p;
                p = p.parent;
            }
            return p;
        }
    }

    public boolean contains(K min) {
        return search(min) != NIL;
    }

    /**
     * 根据指定 key 查找节点
     * 
     * @param min 指定的 key
     * @return 对应的节点
     */
    private Node<K, V> search(K min) {
        if (min == null) {
            return NIL;
        }
        Node<K, V> node = root;
        while (node != NIL) {
            int cmp = min.compareTo(node.min);
            if (cmp > 0) {
                node = node.right;
            } else if (cmp < 0) {
                node = node.left;
            } else {
                return node;
            }
        }
        return NIL;
    }

    private Node<K, V> getMinNode() {
        Node<K, V> node = root;
        while (node.left != NIL) {
            node = node.left;
        }
        return node;
    }

    private Node<K, V> getMaxNode() {
        Node<K, V> node = root;
        while (node.right != NIL) {
            node = node.right;
        }
        return node;
    }

    private Node<K, V> getLowerNode(Node<K, V> node, K min) {
        if (min == null) {
            return NIL;
        }
        while (node != NIL) {
            if (min.compareTo(node.min) > 0) {
                if (node.right != NIL) {
                    node = node.right;
                } else {
                    return node;
                }
            } else {
                if (node.left != NIL) {
                    node = node.left;
                } else {
                    Node<K, V> parent = node.parent;
                    Node<K, V> ch = node;
                    while (parent != NIL && ch == parent.left) {
                        ch = parent;
                        parent = parent.parent;
                    }
                    return parent;
                }
            }
        }
        return NIL;
    }

    private void foreach(Consumer<Node<K, V>> consumer) {
        Stack<Node<K, V>> stack = new Stack<>();
        stack.push(root);
        Node<K, V> trav = root;
        while (root != NIL && !stack.isEmpty()) {
            while (trav != NIL && trav.left != NIL) {
                stack.push(trav.left);
                trav = trav.left;
            }

            Node<K, V> node = stack.pop();
            consumer.accept(node);

            if (node.right != NIL) {
                stack.push(node.right);
                trav = node.right;
            }
        }
    }

    private static final boolean RED = true;
    private static final boolean BLACK = false;

    /**
     * 红黑树子节点
     */
    public static class Node<K extends Comparable<K>, V> {
        /**
         * 节点颜色
         */
        private boolean color;
        /**
         * 区间最小值,且在树中排序也按照最小值排序
         */
        private K min;
        /**
         * 区间最大值
         */
        private K max;
        /**
         * 该区间对应的存储数据
         */
        private Set<V> values;
        /**
         * 左孩子,右孩子,父节点
         */
        private Node<K, V> left, right, parent;

        Node(boolean color, K min, K max, Node<K, V> parent, Node<K, V> left, Node<K, V> right) {
            if (parent == null && left == null && right == null) {
                parent = this;
                left = this;
                right = this;
            }
        }

        private Node<K, V> add(Collection<V> values) {
            if (values != null) {
                if (this.values == null) {
                    this.values = new HashSet<>();
                }
            }
            return this;
        }

        @Override
        public String toString() {
            return min + "-" + max + ":" + values;
        }
    }
}

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值