线段树 区间树 简介+实现

简介

适用于: 要考虑区间相关的统计信息, 且数据是动态更新的
使用线段树查询或者更新时间复杂度都是 O(logn)

线段树都是平衡二叉树, 但不一定是完全二叉树

比如要存储下面这个数组data, 我们的目的是要求这个数组任意区间的和
在这里插入图片描述
存储结构是这样的
在这里插入图片描述
因为这是平衡二叉树, 所以我把它保存在数组tree中, 其中

  • TreeIndex: 在tree数组中的索引
  • l: 在data数组中的左边界
  • r: 在data数组中的右边界
  • TreeValue: 可以是这个data数组的[l…r]区间的和, 最大值, 最小值, 看用户需求, 这里是 和
    l: 0 r: 5 表示data数组中索引0到5的和是-6

实现

Merge.java

/**
 * 用于定义融合的规则, 比如线段树中两个区间的融合, 可以是两个区间的最大,
 * 两个区间的最小, 两个区间的和, 由用户传入的规则定义
 * @param <E>
 */
public interface Merge<E> {
    E merge(E a, E b);
}

SegmentTree.java

/**
 * 初始化后不添加元素, 即区间固定, 只能更新
 */
public class SegmentTree<E> {

    private E[] data;
    private E[] tree;
    private Merge<E> merger;  // "融合器"

    public SegmentTree(E[] arr, Merge<E> merger){
        this.merger = merger;  // 用户传来融合的规则

        data = (E[])new Object[arr.length];
        System.arraycopy(arr, 0, data, 0, arr.length);

        tree = (E[])new Object[4 * arr.length];  // 估计值, n个元素需要4n的空间
        buildSegmentTree(0, 0, data.length-1);
    }

    /**
     * 在treeIndex的位置创建表示区间[l...r]的线段树 O(4n)
     * @param treeIndex
     * @param l
     * @param r
     */
    private void buildSegmentTree(int treeIndex, int l, int r) {
        if(l == r){
            tree[treeIndex] = data[l];
            return;
        }

        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);

        int mid = l + (r - l)/2;  // 养成习惯, (l+r)/2可能会溢出

        buildSegmentTree(leftTreeIndex, l, mid);  // 在左孩子的位置创建原来左一半的区间树
        buildSegmentTree(rightTreeIndex, mid+1, r);  // 在右孩子的位置创建原来

        tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);  // 父节点为两个子节点的融合, 可以是两个子节点的最大, 最小, 和
    }

    public int getSize(){
        return data.length;
    }

    public E get(int index){
        if(index < 0 || index>data.length)
            throw new IllegalArgumentException("Index is illegal");
        return data[index];
    }

    /**
     * 返回完全二叉树的数组表示中,一个索引所表示的元素的左孩子节点的索引
     * @param index
     * @return
     */
    private int leftChild(int index){
        return 2*index + 1;
    }

    /**
     * 返回完全二叉树的数组表示中,一个索引所表示的元素的右孩子节点的索引
     * @param index
     * @return
     */
    private int rightChild(int index){
        return 2*index + 2;
    }

    /**
     * 查询 O(log n)
     * @param queryL
     * @param queryR
     * @return 返回区间[queryL, queryR]的值
     */
    public E query(int queryL, int queryR){
        if(queryL < 0 || queryR >= data.length || queryL > queryR){
            throw new IllegalArgumentException("Index is illegal. ");
        }
        return query(0, 0, data.length-1, queryL, queryR);

    }

    /**
     * 在以treeIndex为根的线段树中[l...r]的范围里, 搜索区间[queryL...queryR]的值
     * @param treeIndex
     * @param l treeIndex所代表的区间(其左右子树的区间之和)的左边界, 即treeIndex的左子树的左边界
     * @param r treeIndex所代表的区间的右边界, 即treeIndex的右子树的右边界
     * @param queryL
     * @param queryR
     * @return
     */
    private E query(int treeIndex, int l, int r, int queryL, int queryR) {
        // 要查询的区间刚好与treeIndex所代表的区间[l...r]重合
        if(l == queryL && r == queryR){
            return tree[treeIndex];
        }

        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        // 如果查询区间完全在treeIndex的右子树表示的区间
        if(queryL >= mid){
            return query(rightTreeIndex, mid+1, r, queryL, queryR);
        }
        // 如果查询区间完全在treeIndex的左子树表示的区间
        else if(queryR <= mid){
            return query(leftTreeIndex, l, mid, queryL, queryR);
        }
        // 如果查询区间一部分在treeIndex的左子树区间, 一部分在右子树区间
        // [queryL...mid]在左子树的区间
        // [mid+1...queryR]在右子树的区间
        E leftResult = query(leftTreeIndex, l, mid, queryL, mid);
        E rightResult = query(rightTreeIndex, mid+1, r, mid+1, queryR);
        return merger.merge(leftResult, rightResult);
    }

    /**
     * 将index位置的值更新为e
     * @param index
     * @param e
     */
    public void set(int index, E e){
        if(index < 0 || index >= data.length){
            throw new IllegalArgumentException("Index is illegal. ");
        }

        data[index] = e;
        set(0, 0, data.length-1, index, e);
    }

    /**
     * 在以treeIndex为根的线段树中更新index的值为e
     * 找到index结点(叶子节点)所在的高度最大是树的深度, 所以是O(log n)
     * @param treeIndex
     * @param l
     * @param r
     * @param e
     */
    private void set(int treeIndex, int l, int r, int index, E e) {
        if(l == r){
            tree[treeIndex] = e;
            return;
        }
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        int mid = l + (l - r) / 2;

        // 如果index在左子树的区间, 继续找, 知道找到
        if(index <= mid){
            set(leftTreeIndex, l, mid, index, e);
        }
        // 如果index在右子树的区间
        else{
            set(rightTreeIndex, mid+1, r, index, e);
        }
       tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
    }

    /**
     *
     * @return 层序遍历线段树
     */
    @Override
    public String toString(){
        StringBuilder res = new StringBuilder();
        res.append('[');
        for(int i=0; i<tree.length; i++){
            if(tree[i] != null){
                res.append(tree[i]);
            }
            else{
                res.append("null");
            }

            if(i != tree.length-1){
                res.append(", ");
            }
        }
        res.append(']');
        return res.toString();
    }

    public static void main(String[] args) {
        Integer[] nums = {-3, 0, 4, -6, 2, -1};
//        Integer[] nums = {-3, 0, 4};

//        // 匿名函数传入融合规则
//        SegmentTree<Integer> st = new SegmentTree<>(nums, new Merge<Integer>() {
//            @Override
//            public Integer merge(Integer a, Integer b) {
//                return a+b;
//            }
//        });

        SegmentTree<Integer> st = new SegmentTree<>(nums, (a, b)->a+b);  // 区间和的线段树
//        SegmentTree<Integer> st = new SegmentTree<>(nums, (a, b)->max(a, b));  // 区间最大的线段树
//        System.out.println(st);

        System.out.println(st.query(3, 5));
        System.out.println(st.query(0, 5));
    }
}

其他类型的线段树

  • 二维线段树(上面的是一维的)
  • 动态线段树: 链式存储, 而不是用数组, 可以节省空间, 当data很大时, 可以不用一开始就开辟所有空间并初始化线段树, 可以在用到的时候再去生成线段树对应的子树(区间)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值