线段树特点
特点:平衡二叉树,使用数组表示数的结构。
- 通过给定的数组对象,获取需要生成线段树的高度(2的h-1次方大于数组长度,取h的最小值);
- 从而得到线段树所需数组的长度(2的h次方减1);通过数组索引构建线段树。
自定义接口
package com.company.segment;
/**
* @Author: wenhua
* @CreateTime: 2023-01-12 19:59
*/
public interface Merger<T> {
T merger(T t1,T t2);
}
自定义线段树
package com.company.segment;
/**
* 自定义线段树
* 特点:平衡二叉树,使用数组表示数的结构。
* 通过给定的数组对象,获取需要生成线段树的高度(2的h-1次方大于数组长度,取h的最小值),
* 从而得到线段树所需数组的长度(2的h次方减1)。
* 通过数组索引构建线段树。
*
* @param <T> 泛型
* @Author: wenhua
* @CreateTime: 2023-01-12 15:32
*/
public class SegmentTree<T> {
Merger<T> merger;
public SegmentTree(Merger merger) {
this.merger = merger;
}
/**
* 创建线段树数组对象
*
* @param arr
* @return
*/
public T[] buildArray(T[] arr) {
if (arr == null || arr.length == 0) {
return null;
}
// 通过已知数组,计算线段树高度
int height = (int) Math.ceil(Math.log(arr.length) / Math.log(2) + 1);
// 创建线段树(按照满二叉树结点个数,即最坏情况)数组对象
return (T[]) new Integer[(int) (Math.pow(2, height) - 1)];
}
/**
* 求和线段树
*
* @param arr 初始数组
* @return 返回线段树对象
*/
public T[] sumSegment(T[] arr) {
// 构建线段树
return buildSumSegment(arr, buildArray(arr), 0, arr.length - 1, 0);
}
/**
* 最大值线段树
*
* @param arr 初始数组
* @return 返回线段树对象
*/
public T[] maxSegment(T[] arr) {
// 构建线段树
return buildMaxSegment(arr, buildArray(arr), 0, arr.length - 1, 0);
}
/**
* 构建求和线段树
*
* @param arr 原始数组
* @param segmentTree 线段树数组对象
* @param start 初始索引
* @param end 结束索引
* @param index 当前索引
* @return 返回线段树对象
*/
private T[] buildSumSegment(T[] arr, T[] segmentTree, int start, int end, int index) {
// 递归结束条件,当初始索引等于结束索引时,存储数据并返回线段树
if (start == end) {
segmentTree[index] = arr[start];
return segmentTree;
}
/**
* 如果不相等,则将索引分开
* 将中间索引为start+(end-start)/2,因为(start+end)/2有可能超出索引范围
* 左右索引分别遵循二叉树的性质,
* 通过新的索引分别构建左右线段树
*/
int mid = start + (end - start) / 2;
int leftIndex = index * 2 + 1;
int rightIndex = leftIndex + 1;
buildSumSegment(arr, segmentTree, start, mid, leftIndex);
buildSumSegment(arr, segmentTree, mid + 1, end, rightIndex);
// 求得左右子树的和赋值给当前索引对应位置
segmentTree[index] = merger.merger(segmentTree[leftIndex], segmentTree[rightIndex]);
return segmentTree;
}
/**
* 构建最大值线段树
*
* @param arr 原始数组
* @param segmentTree 线段树数组对象
* @param start 初始索引
* @param end 结束索引
* @param index 当前索引
* @return 返回线段树对象
*/
private T[] buildMaxSegment(T[] arr, T[] segmentTree, int start, int end, int index) {
// 递归结束条件,当初始索引等于结束索引时,存储数据并返回线段树
if (start == end) {
segmentTree[index] = arr[start];
return segmentTree;
}
int mid = start + (end - start) / 2;
int leftIndex = index * 2 + 1;
int rightIndex = leftIndex + 1;
buildMaxSegment(arr, segmentTree, start, mid, leftIndex);
buildMaxSegment(arr, segmentTree, mid + 1, end, rightIndex);
// 求得左右子树的最大值赋值给当前索引对应位置
segmentTree[index] = merger.merger(segmentTree[leftIndex], segmentTree[rightIndex]);
return segmentTree;
}
/**
* 计算指定区间数据的和
*
* @param segmentTree 线段树数组对象
* @param end 结束索引
* @param from 指定初始索引
* @param to 指定结束索引
* @return 返回值
*/
public T querySum(T[] segmentTree, int end, int from, int to) {
return querySum(segmentTree, 0, end, 0, from, to);
}
/**
* 计算指定区间数据的和
*
* @param segmentTree 线段树数组对象
* @param start 初始索引
* @param end 结束索引
* @param index 当前索引
* @param from 指定初始索引
* @param to 指定结束索引
* @return 返回值
*/
private T querySum(T[] segmentTree, int start, int end, int index, int from, int to) {
// 结束条件,当初始索引和指定初始索引相等且结束索引和指定结束索引相等时,返回当前索引对应数据
if (start == from && end == to) {
return segmentTree[index];
}
int mid = start + (end - start) / 2;
int leftIndex = index * 2 + 1;
int rightIndex = leftIndex + 1;
// 指定初始索引大于mid时,向右寻找
if (mid < from) {// 往右找
return querySum(segmentTree, mid + 1, end, rightIndex, from, to);
} else if (to <= mid) {// // 指定结束索引小于等于mid时,向左寻找 往左找
return querySum(segmentTree, start, mid, leftIndex, from, to);
}
// 当不属于以上情况时,就需要分区间求和
return merger.merger(querySum(segmentTree, mid + 1, end, rightIndex, mid + 1, to),
querySum(segmentTree, start, mid, leftIndex, from, mid));
}
/**
* 查询指定区间数据的最大值
*
* @param segmentTree 线段树数组对象
* @param end 结束索引
* @param from 指定初始索引
* @param to 指定结束索引
* @return 返回值
*/
public T queryMax(T[] segmentTree, int end, int from, int to) {
return queryMax(segmentTree, 0, end, 0, from, to);
}
/**
* 查询指定区间数据的最大值
*
* @param segmentTree 线段树数组对象
* @param start 初始索引
* @param end 结束索引
* @param index 当前索引
* @param from 指定初始索引
* @param to 指定结束索引
* @return 返回值
*/
private T queryMax(T[] segmentTree, int start, int end, int index, int from, int to) {
// 结束条件,当初始索引和指定初始索引相等且结束索引和指定结束索引相等时,返回当前索引对应数据
if (start == from && end == to) {
return segmentTree[index];
}
int mid = start + (end - start) / 2;
int leftIndex = index * 2 + 1;
int rightIndex = leftIndex + 1;
// 指定初始索引大于mid时,向右寻找
if (mid < from) {// 往右找
return querySum(segmentTree, mid + 1, end, rightIndex, from, to);
} else if (to <= mid) {// // 指定结束索引小于等于mid时,向左寻找 往左找
return querySum(segmentTree, start, mid, leftIndex, from, to);
}
// 当不属于以上情况时,就需要分区间求和
return merger.merger(querySum(segmentTree, mid + 1, end, rightIndex, mid + 1, to),
querySum(segmentTree, start, mid, leftIndex, from, mid));
}
}
测试
package com.company.segment;
/**
* @Author: wenhua
* @CreateTime: 2023-01-12 16:01
*/
public class Main {
public static void main(String[] args) {
Main main = new Main();
Integer[] arr = {1, 3, 5, 7, 9, 11, 13};
main.SumSegmentTest(arr);// 求和线段树
main.MaxSegmentTest(arr);// 最大值线段树
main.querySumTest(arr);// 求和线段树指定区间
main.queryMaxTest(arr);// 求线段树指定区间的最大值
}
public void SumSegmentTest(Integer[] arr) {
SegmentTree<Integer> segmentTree = new SegmentTree<>((Merger<Integer>) (t1, t2) -> t1 + t2);
Integer[] sumSegment = segmentTree.sumSegment(arr);
StringBuffer ssbf = new StringBuffer();
for (int i = 0; i < sumSegment.length; i++) {
ssbf.append(sumSegment[i]);
if (i != sumSegment.length - 1) {
ssbf.append(",");
}
}
System.out.println("求和线段树:" + ssbf.toString());
}
public void MaxSegmentTest(Integer[] arr) {
SegmentTree<Integer> segmentTree = new SegmentTree<Integer>((Merger<Integer>) (t1, t2) -> t1 > t2 ? t1 : t2);
Integer[] maxSegment = segmentTree.maxSegment(arr);
StringBuffer msbf = new StringBuffer();
for (int i = 0; i < maxSegment.length; i++) {
msbf.append(maxSegment[i]);
if (i != maxSegment.length - 1) {
msbf.append(",");
}
}
System.out.println("最大值线段树:" + msbf.toString());
}
public void querySumTest(Integer[] arr) {
SegmentTree<Integer> trie = new SegmentTree<Integer>(new Merger<Integer>() {
@Override
public Integer merger(Integer t1, Integer t2) {
return t1 + t2;
}
});
Integer[] sumSegment = trie.sumSegment(arr);
Integer result = trie.querySum(sumSegment, arr.length - 1, 1, 3);
System.out.println("求和线段树指定区间(1,3):" + result);
}
public void queryMaxTest(Integer[] arr) {
SegmentTree<Integer> trie = new SegmentTree<Integer>(new Merger<Integer>() {
@Override
public Integer merger(Integer t1, Integer t2) {
return t1 > t2 ? t1 : t2;
}
});
Integer[] maxSegment = trie.maxSegment(arr);
StringBuffer msbf = new StringBuffer();
for (int i = 0; i < maxSegment.length; i++) {
msbf.append(maxSegment[i]);
if (i != maxSegment.length - 1) {
msbf.append(",");
}
}
// System.out.println(msbf.toString());
Integer result = trie.queryMax(maxSegment, arr.length - 1, 1, 3);
System.out.println("求线段树指定区间的最大值(1,3):" + result);
}
}
测试结果
求和线段树:49,16,33,4,12,20,13,1,3,5,7,9,11,null,null
最大值线段树:13,7,13,3,7,11,13,1,3,5,7,9,11,null,null
求和线段树指定区间(1,3):15
求线段树指定区间的最大值(1,3):7