线段树(区间树)
一、线段树基础
- 为什么使用线段树?对于有一类问题,我们关心的是线段(或者区间)
最经典的线段树问题:区间染色
- 从4~9绘制成橙色
- 将7~15的部分绘制成绿色(原来橙色的部分区间被绿色覆盖了)
- 对1~5的地方绘制成蓝色 ,6~12的部分绘制成红色
问:经过m次操作后,我们可以看见多少种颜色?我们可以在[i,j]区间内看见多少种颜色?
-
另一类经典问题:区间查询
上述问题可以通过数组和线段树解决,但是线段树的时间复杂度低
-
线段树:同样也是一个二叉树,只是它的每一个节点表示的都是一个区间内的所有信息。(以求和为例,A[4…7]存的就是A[4]~A[7]的和)
如果根节点有10个元素时:
平衡二叉树:对于整棵树来说,最大的叶子节点的深度和最小的叶子节点的深度最多的差距只能为1. -
如果区间有n个元素,数组表示需要有多少节点
如果区间有n个元素,数组表示需要有多少节点?需要4n的空间
我们的线段树不考虑添加元素,即区间固定。使用4n的静态空间即可
二、线段树的创建
- 由数组表示的线段树应具备几个基础功能:获取尺寸、根据索引得到相应值、返回当前索引的左右孩子索引。代码实现:
public class SegmentTree<E> {
private E[] tree;
private E[] data;
public SegmentTree(E[] arr){
data = (E[])new Object[arr.length];
for(int i = 0 ; i < arr.length ; i ++)
data[i] = arr[i];
tree = (E[])new Object[4 * arr.length]; //4n的空间
}
public int getSize(){
return data.length;
}
public E get(int index){
if(index < 0 || index >= data.length)
throw new IllegalArgumentException("Index is illegal.");
return data[index];
}
// 返回完全二叉树的数组表示中,一个索引所表示的元素的左孩子节点的索引
private int leftChild(int index){
return 2*index + 1;
}
// 返回完全二叉树的数组表示中,一个索引所表示的元素的右孩子节点的索引
private int rightChild(int index){
return 2*index + 2;
}
}
- 创建线段树:使用递归操作,自定义Merge接口的作用是:在合并小线段的相应信息来获取更大的信息时,为了使得当前的操作不仅仅只能够实现固定的方法,可以通过实现一个融合器方法并且在创建线段树的时候传递进去一个融合器方法来实现能够自定义线段树合并方法的目的。
- 融合器接口实现:
public interface Merge<E> {
E merge(E a,E b);
}
- 创建线段树:
public class SegmentTree<E>{
private E[] data;
private E[] tree;
private Merge<E> merger;
public SegmentTree(E[] arr,Merge<E> merger) {
this.merger = merger;
data = (E[])new Object[arr.length];
for(int i=0;i<arr.length;i++) {
data[i] = arr[i];
}
tree = (E[])new Object[4*arr.length]; //4n的空间
buildSegmentTree(0,0,arr.length-1);
}
//在treeIndex的位置创建表示区间[l...r]的线段树
private void buildSegmentTree(int treeIndex,int l,int r) {
if(l==r) {
tree[treeIndex] = data[l];
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
int mid = l+(r-l)/2; //等价于(l+r)/2.但是l+r可能存在整型溢出的问题
buildSegmentTree(leftTreeIndex,l,mid);
buildSegmentTree(rightTreeIndex,mid+1,r);
//综合两个线段相应的信息来得到更大的信息。使用merge接口使得融合的方法可以由用户自己传入
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
public int getSize() {
return data.length;
}
public E get(int index) {
if(index<0 || index>=data.length) {
throw new IllegalArgumentException("错误");
}
return data[index];
}
//返回完全二叉树的数组表示中,一个索引所表示的元素的左孩子节点的索引
private int leftChild(int index) {
return 2*index+1;
}
//返回完全二叉树的数组表示中,一个索引所表示的元素的右孩子节点的索引
private int rightChild(int index) {
return 2*index+2;
}
@Override
public String toString() {
StringBuilder res = new StringBuilder();
res.append('[');
for(int i=0;i<tree.length;i++) {
if(tree[i]!=null) {
res.append(tree[i]);
}else {
res.append("null");
}
if(i!=tree.length-1) {
res.append(",");
}
}
res.append(']');
return res.toString();
}
}
- 测试
public class Main {
public static void main(String[] args) {
Integer[] nums = {-2,0,3,-5,2,1};
SegmentTree<Integer> segTree = new SegmentTree<Integer>(nums,(a,b)->a+b);//lamda表达式
System.out.println(segTree);
}
}
线段树的查询操作:如查询[2,5]
- 添加查询方法
代码实现
- 添加公有的查询方法及私有的查询方法
//返回区间[queryL,queryR]的值
public E query(int queryL,int queryR) {
if(queryL<0||queryR>data.length || queryL>queryR) {
throw new IllegalArgumentException("index is illegal");
}
return query(0,0,data.length-1,queryL,queryR);
}
//在以treeID为根的线段树中[l...r]的范围里,搜索区间[queryL...queryR]的值
private E query(int treeIndex,int l,int r,int queryL,int queryR) {
if(l == queryL && r==queryR) {
return tree[treeIndex];
}
int mid = l + (r - l)/2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
if(queryL>=mid+1) {
return query(rightTreeIndex,mid+1,r,queryL,queryR); //到右子树中进行查找
}else if(queryR<=mid){
return query(leftTreeIndex,l,mid,queryL,queryR); //到左子树中进行查找
}
//如果查找的区间既不全部在左孩子区间,又不全在右孩子区间
E leftResult = query(leftTreeIndex,l,mid,queryL,mid);
E rightResult = query(rightTreeIndex,mid+1,r,mid+1,queryR);
return merger.merge(leftResult, rightResult);
}
三、LeetCode题目
LeetCode 303:区域和检索-不可变
- 利用线段树进行解答
class NumArray {
private SegmentTree<Integer> segmentTree;
public NumArray(int[] nums) {
if(nums.length > 0) {
Integer[] data = new Integer[nums.length];
for(int i=0;i<nums.length;i++) {
data[i] = nums[i];
}
segmentTree = new SegmentTree<>(data,(a,b)->a+b);
}
}
public int sumRange(int i, int j) {
if(segmentTree == null) {
throw new IllegalArgumentException("Segment Tree is null");
}
return segmentTree.query(i, j);
}
}
- 利用数组进行解答
由于数组不可修改,因此可以构建一个辅助数组,这个辅助数组存储的是前i个元素的和。这种构建方式非常经典!!需要提高注意
public class NumArray2 {
private int[] sum; //sum[i]存储前i个元素的和,sum[0] = 0
//sum[i]存储nums[0....i-1]的和
public NumArray2(int[] nums) {
sum = new int[nums.length +1];
sum[0] = 0;
for(int i=1;i<sum.length;i++) {
sum[i] = sum[i-1]+nums[i-1];
}
}
public int sumRange(int i, int j) {
return sum[j+1] - sum[i];
}
}
当数组的内容可修改时:
LeetCode 307:区域和检索 - 数组可修改
- 采用数组实现
public class NumArray {
private int[] sum; //sum[i]存储前i个元素的和,sum[0] = 0
//sum[i]存储nums[0....i-1]的和
private int[] data;
public NumArray(int[] nums) {
data =new int[nums.length];
for(int i=0;i<nums.length;i++) {
data[i] = nums[i];
}
sum = new int[nums.length +1];
sum[0] = 0;
for(int i=1;i<sum.length;i++) {
sum[i] = sum[i-1]+nums[i-1];
}
}
public void update(int index, int val) {
data[index] = val;
for(int i = index+1;i<sum.length;i++) {
sum[i] = sum[i-1]+data[i-1];
}
}
public int sumRange(int i, int j) {
return sum[j+1] - sum[i];
}
}
此时执行效率很慢,甚至出现运行时间超出限制报错的问题。原因在于:利用数组进行更新时,每次更新的操作都是O(n)级别的操作。如果进行了n次更新,那么时间复杂度就成了O(n*n)的情况。因此效率大大降低。这时候应该利用线段树实现更新的操作。
- 线段树实现更新的操作
添加set方法:
//将index位置的值,更新为e
public void set(int index,E e) {
if(index<0||index>data.length) {
throw new IllegalArgumentException("index is illegal");
}
data[index] = e;
set(0,0,data.length-1,index,e);
}
//针对treeIndex为根的线段树,对于节点的l到r区间的值,将index的值更新为e
private void set(int treeIndex,int l,int r,int index,E e) {
//这个更新操作从根节点开始,一直找到index节点所在的位置后进行更新。这个高度是logn级别的。因此更新为logn复杂度
if(l == r) {
tree[treeIndex] = e;
return;
}
int mid = l + (r-l)/2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
if(index>=mid+1) {
set(rightTreeIndex,mid+1,r,index,e);
}else {
set(leftTreeIndex,l,mid,index,e);
}
//我们在对每一个节点进行更新的时候,节点相应的父节点中的值也需要进行更新
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
相应的NumArray所做的改变
class NumArray {
private SegmentTree<Integer> segmentTree;
public NumArray(int[] nums) {
if(nums.length > 0) {
Integer[] data = new Integer[nums.length];
for(int i=0;i<nums.length;i++) {
data[i] = nums[i];
}
segmentTree = new SegmentTree<>(data,(a,b)->a+b);
}
}
public void update(int index, int val) {
if(segmentTree == null) {
throw new IllegalArgumentException("Segment Tree is null");
}
segmentTree.set(index, val);
}
public int sumRange(int i, int j) {
if(segmentTree == null) {
throw new IllegalArgumentException("Segment Tree is null");
}
return segmentTree.query(i, j);
}
}