简介
适用于: 要考虑区间相关的统计信息, 且数据是动态更新的
使用线段树查询或者更新时间复杂度都是 O(logn)
线段树都是平衡二叉树, 但不一定是完全二叉树
比如要存储下面这个数组data, 我们的目的是要求这个数组任意区间的和
存储结构是这样的
因为这是平衡二叉树, 所以我把它保存在数组tree中, 其中
- TreeIndex: 在tree数组中的索引
- l: 在data数组中的左边界
- r: 在data数组中的右边界
- TreeValue: 可以是这个data数组的[l…r]区间的和, 最大值, 最小值, 看用户需求, 这里是 和
l: 0 r: 5 表示data数组中索引0到5的和是-6
实现
Merge.java
/**
* 用于定义融合的规则, 比如线段树中两个区间的融合, 可以是两个区间的最大,
* 两个区间的最小, 两个区间的和, 由用户传入的规则定义
* @param <E>
*/
public interface Merge<E> {
E merge(E a, E b);
}
SegmentTree.java
/**
* 初始化后不添加元素, 即区间固定, 只能更新
*/
public class SegmentTree<E> {
private E[] data;
private E[] tree;
private Merge<E> merger; // "融合器"
public SegmentTree(E[] arr, Merge<E> merger){
this.merger = merger; // 用户传来融合的规则
data = (E[])new Object[arr.length];
System.arraycopy(arr, 0, data, 0, arr.length);
tree = (E[])new Object[4 * arr.length]; // 估计值, n个元素需要4n的空间
buildSegmentTree(0, 0, data.length-1);
}
/**
* 在treeIndex的位置创建表示区间[l...r]的线段树 O(4n)
* @param treeIndex
* @param l
* @param r
*/
private void buildSegmentTree(int treeIndex, int l, int r) {
if(l == r){
tree[treeIndex] = data[l];
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
int mid = l + (r - l)/2; // 养成习惯, (l+r)/2可能会溢出
buildSegmentTree(leftTreeIndex, l, mid); // 在左孩子的位置创建原来左一半的区间树
buildSegmentTree(rightTreeIndex, mid+1, r); // 在右孩子的位置创建原来
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]); // 父节点为两个子节点的融合, 可以是两个子节点的最大, 最小, 和
}
public int getSize(){
return data.length;
}
public E get(int index){
if(index < 0 || index>data.length)
throw new IllegalArgumentException("Index is illegal");
return data[index];
}
/**
* 返回完全二叉树的数组表示中,一个索引所表示的元素的左孩子节点的索引
* @param index
* @return
*/
private int leftChild(int index){
return 2*index + 1;
}
/**
* 返回完全二叉树的数组表示中,一个索引所表示的元素的右孩子节点的索引
* @param index
* @return
*/
private int rightChild(int index){
return 2*index + 2;
}
/**
* 查询 O(log n)
* @param queryL
* @param queryR
* @return 返回区间[queryL, queryR]的值
*/
public E query(int queryL, int queryR){
if(queryL < 0 || queryR >= data.length || queryL > queryR){
throw new IllegalArgumentException("Index is illegal. ");
}
return query(0, 0, data.length-1, queryL, queryR);
}
/**
* 在以treeIndex为根的线段树中[l...r]的范围里, 搜索区间[queryL...queryR]的值
* @param treeIndex
* @param l treeIndex所代表的区间(其左右子树的区间之和)的左边界, 即treeIndex的左子树的左边界
* @param r treeIndex所代表的区间的右边界, 即treeIndex的右子树的右边界
* @param queryL
* @param queryR
* @return
*/
private E query(int treeIndex, int l, int r, int queryL, int queryR) {
// 要查询的区间刚好与treeIndex所代表的区间[l...r]重合
if(l == queryL && r == queryR){
return tree[treeIndex];
}
int mid = l + (r - l) / 2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// 如果查询区间完全在treeIndex的右子树表示的区间
if(queryL >= mid){
return query(rightTreeIndex, mid+1, r, queryL, queryR);
}
// 如果查询区间完全在treeIndex的左子树表示的区间
else if(queryR <= mid){
return query(leftTreeIndex, l, mid, queryL, queryR);
}
// 如果查询区间一部分在treeIndex的左子树区间, 一部分在右子树区间
// [queryL...mid]在左子树的区间
// [mid+1...queryR]在右子树的区间
E leftResult = query(leftTreeIndex, l, mid, queryL, mid);
E rightResult = query(rightTreeIndex, mid+1, r, mid+1, queryR);
return merger.merge(leftResult, rightResult);
}
/**
* 将index位置的值更新为e
* @param index
* @param e
*/
public void set(int index, E e){
if(index < 0 || index >= data.length){
throw new IllegalArgumentException("Index is illegal. ");
}
data[index] = e;
set(0, 0, data.length-1, index, e);
}
/**
* 在以treeIndex为根的线段树中更新index的值为e
* 找到index结点(叶子节点)所在的高度最大是树的深度, 所以是O(log n)
* @param treeIndex
* @param l
* @param r
* @param e
*/
private void set(int treeIndex, int l, int r, int index, E e) {
if(l == r){
tree[treeIndex] = e;
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
int mid = l + (l - r) / 2;
// 如果index在左子树的区间, 继续找, 知道找到
if(index <= mid){
set(leftTreeIndex, l, mid, index, e);
}
// 如果index在右子树的区间
else{
set(rightTreeIndex, mid+1, r, index, e);
}
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
/**
*
* @return 层序遍历线段树
*/
@Override
public String toString(){
StringBuilder res = new StringBuilder();
res.append('[');
for(int i=0; i<tree.length; i++){
if(tree[i] != null){
res.append(tree[i]);
}
else{
res.append("null");
}
if(i != tree.length-1){
res.append(", ");
}
}
res.append(']');
return res.toString();
}
public static void main(String[] args) {
Integer[] nums = {-3, 0, 4, -6, 2, -1};
// Integer[] nums = {-3, 0, 4};
// // 匿名函数传入融合规则
// SegmentTree<Integer> st = new SegmentTree<>(nums, new Merge<Integer>() {
// @Override
// public Integer merge(Integer a, Integer b) {
// return a+b;
// }
// });
SegmentTree<Integer> st = new SegmentTree<>(nums, (a, b)->a+b); // 区间和的线段树
// SegmentTree<Integer> st = new SegmentTree<>(nums, (a, b)->max(a, b)); // 区间最大的线段树
// System.out.println(st);
System.out.println(st.query(3, 5));
System.out.println(st.query(0, 5));
}
}
其他类型的线段树
- 二维线段树(上面的是一维的)
- 动态线段树: 链式存储, 而不是用数组, 可以节省空间, 当data很大时, 可以不用一开始就开辟所有空间并初始化线段树, 可以在用到的时候再去生成线段树对应的子树(区间)