package com.yc.tree;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
/**
* @author wb
* @param <T>
*
* 这里我想告诉大家几点:
* 1.区间[i,i+1]表示第i个元素。
*
* 2.我后来发现在区间更新这块有两种方法:第一种就是直接去修改指定区间叶子节点的值(例如区间[1,3]加3,就是[1,1][2,2][3,3]处叶子节点+3);
* 第二种也就是我采用的方法,加延迟标志,因为这种方法正是线段树的精华所在。
*
* 3.我这里采用的还是链式存储,一个很明显的劣势就是当数据项即线段过多时其查询效率明显没有顺序存储结构快。
*
*
* 4.下面我写的代码个人觉得真的烂,就比如在Node的ADT上设计的烂:在求和的时候大量的类型转化
*
*/
public class SegmentTree <T extends Comparable<T>>{
private class Node{
int l,r;//区间范围
T min; //区间最小值
T max;//区间最大值
T sum;//区间和
T delta;//(延迟标志:区间增加值),
T same;//(延迟标志:区间被统一置为某个值),(其实这里可以不要,反而用delta来表示,delta=same-min)
int count; //(延迟标志:区间访问次数) (下面的代码没有用到)
//Node parent;
Node left;
Node right;
public Node(int l, int r, Node left, Node right, int count){
this.l = l;
this.r = r;
this.left = left;
this.right = right;
this.count = count;
}
public String toString(){
return "["+l+"~"+r+":min="+min+",max="+max+",sum="+sum+",delta="+delta+"]";
}
}
//根
private Node root;
public SegmentTree(){
root = null;
}
/*public SegmentTree(int l,int r){
}*/
private T minVal(T t1, T t2){
return t1.compareTo(t2) > 0 ? t2 : t1;
}
private T maxVal(T t1, T t2){
return t1.compareTo(t2) > 0 ? t1 : t2;
}
private int getMid(int t1, int t2){
return t1 + (t2 - t1)/2;
}
/**
* 以指定数组构造链式线段树
* @param data
* @param start
* @param end
* @return
*/
private Node constructUtil(T[] data, int start, int end){
if(root == null){
root = new Node(start, end, null, null, 0);
root.left = constructUtil(data, start, getMid(start,end));
root.right = constructUtil(data, getMid(start,end) + 1, end);
root.min = minVal(root.left.min, root.right.min);
root.max = maxVal(root.left.max, root.right.max);
root.sum = sumOf(root.left.sum, root.right.sum);
return root;
}else{
Node node = new Node(start, end, null, null, 0);
if(start == end){
node.min = node.max = node.sum = data[start];
return node;
}
node.left = constructUtil(data, start, getMid(start,end));
node.right = constructUtil(data, getMid(start,end) + 1, end);
node.min = minVal(node.left.min, node.right.min);
node.max = maxVal(node.left.max, node.right.max);
node.sum = sumOf(node.left.sum, node.right.sum);
return node;
}
}
public void construct(T[] data){
if(data != null){
int length = data.length;
constructUtil(data, 0, length - 1);
}
}
/**
* 最小值获取
* @param n:原始数组的长度
* @param s:开始索引
* @param e:结束索引
* @return
*/
@SuppressWarnings("unchecked")
public T getMin(int n, int s, int e){
if (s < 0 || e > n - 1 || s > e) {
System.out.println("Invalid Input");
return (T)(Integer)(-1);
}
return getMinUtil(root, s, e);
}
@SuppressWarnings("unchecked")
private T getMinUtil(Node node, int s, int e){
Node current = node;
if(root == null){
System.out.println("还未构造线段树!");
return null;
}else{
if(current != null){
int l = current.l;
int r = current.r;
if(s <= l && r <= e){
return current.min;
}
if (e < l || s > r){
return (T)(Integer)Integer.MAX_VALUE;
}
pushDown(node); //这一句代码可以不要,因为已经在update方法中更新完了?
int mid = getMid(l, r);
return minVal(getMinUtil(current.left, s, mid), getMinUtil(current.right, mid + 1, e));
}
return null;
}
}
/**
* 最大值获取
* @param n:原始数组的长度
* @param s:开始索引
* @param e:结束索引
* @return
*/
@SuppressWarnings("unchecked")
public T getMax(int n, int s, int e){
if (s < 0 || e > n - 1 || s > e) {
System.out.println("Invalid Input");
return (T)(Integer)(-1);
}
return getMaxUtil(root,s, e);
}
@SuppressWarnings("unchecked")
private T getMaxUtil(Node node, int s, int e) {
Node current = node;
if(root == null){
System.out.println("还未构造线段树!");
return null;
}else{
int l = current.l;
int r = current.r;
if(s <= l && r <= e){
return current.max;
}
if (e < l || s > r){
return (T)(Integer)Integer.MIN_VALUE;
}
pushDown(node);//这一句代码可以不要,因为已经在update方法中更新完了
int mid = getMid(l, r);
return maxVal(getMaxUtil(current.left, s,mid), getMaxUtil(current.right, mid + 1, e));
}
}
@SuppressWarnings("unchecked")
private T sumOf(T t1, T t2){
return (T)(Integer)(Integer.parseInt(String.valueOf(t1))+Integer.parseInt(String.valueOf(t2)));
}
@SuppressWarnings("unchecked")
private T minuOf(T t1, T t2){
return (T)(Integer)(Integer.parseInt(String.valueOf(t1))-Integer.parseInt(String.valueOf(t2)));
}
/**
* 求给定区间和
* @param n:原始数组的长度
* @param s:开始索引
* @param e:结束索引
* @return
*/
@SuppressWarnings("unchecked")
public T getSum(int n, int s, int e){
if (s < 0 || e > n - 1 || s > e) {
System.out.println("Invalid Input");
return (T)(Integer)(-1);
}
return getSumUtil(root, s, e);
}
@SuppressWarnings("unchecked")
private T getSumUtil(Node node, int s, int e) {
Node current = node;
if(root == null){
System.out.println("还未构造线段树!");
return null;
}else{
int l = current.l;
int r = current.r;
if(s <= l && r <= e){
return current.sum;
}
if (e < l || s > r){
return (T)(Integer)0;
}
pushDown(node);//这一句代码可以不要,因为已经在update方法中更新完了
int mid = getMid(l, r);
return sumOf(getSumUtil(node.left, s,mid),getSumUtil(node.right, mid + 1,e));
}
}
/**
* 在给定的区间添加某值
* @param delte
* @param s
* @param e
*/
public void updateDelta(T delte, int s, int e){
int l = root.l;
int r = root.r;
updateDeltaUtil(root, l, r, delte, s, e);
}
private void updateDeltaUtil(Node node, int l, int r, T delta, int s, int e){
if(node != null){
if(e < l || s > r){
return;
}
if(s <= l && r <= e){
node.delta = node.delta == null ? delta : sumOf(node.delta, delta);
node.min = sumOf(node.min, delta);
node.max = sumOf(node.max, delta);
node.sum = sumOf(node.sum, delta);
return;
}
pushDown(node); //延迟标志向下传递
//更新左右孩子节点
int mid = getMid(l, r);
updateDeltaUtil(node.left, l, mid, delta, s, e);
updateDeltaUtil(node.right, mid + 1, r, delta, s, e);
//根据左右子树的值回溯跟新当前节点的值
node.min = minVal(node.left.min, node.right.min);
node.max = maxVal(node.left.max, node.right.max);
node.sum = sumOf(node.left.sum, node.right.sum);
}
}
/**
* 将区间[s,e]统一设为相同的same值
* @param same
* @param s
* @param e
*/
public void updateSame(T same, int s, int e){
int l = root.l;
int r = root.r;
updateSameUtil(root, l, r, same, s, e);
}
private void updateSameUtil(Node node, int l, int r, T same, int s, int e) {
if(node != null){
if(e < l || r < s){
return;
}
if(s <= l && r <= e){
T delta = minuOf(same, node.min);
node.delta = delta;
node.min = sumOf(node.min, delta);
node.max = sumOf(node.max, delta);
node.sum = sumOf(node.sum, delta);
return;
}
pushDown(node); //延迟标志向下传递
//更新左右孩子节点
int mid = getMid(l, r);
updateSameUtil(node.left, l, mid, same, s, e);
updateSameUtil(node.right, mid + 1, r, same, s, e);
//根据左右子树的值回溯跟新当前节点的值
node.min = minVal(node.left.min, node.right.min);
node.max = maxVal(node.left.max, node.right.max);
node.sum = sumOf(node.left.sum, node.right.sum);
}
}
private void pushDown(Node node){ //delta域向下传递至叶子节点
if(node != null){
if(node.delta != null){
node.left.delta = sumOf(node.left.delta, node.delta);
node.right.delta = sumOf(node.right.delta, node.delta);
node.left.min = sumOf(node.left.min, node.delta);
node.left.max = sumOf(node.left.max, node.delta);
node.left.sum = sumOf(node.left.max, node.delta);
node.right.min = sumOf(node.right.min, node.delta);
node.right.max = sumOf(node.right.max, node.delta);
node.right.sum = sumOf(node.right.max, node.delta);
node.delta = null;
}
}
}
//广度优先遍历
public List<Node> breadthFirstSearch(){
return cBreadthFirstSearch(root);
}
private List<Node> cBreadthFirstSearch(Node node) {
List<Node> nodes = new ArrayList<Node>();
Deque<Node> deque = new ArrayDeque<Node>();
if(node != null){
deque.offer(node);
}
while(!deque.isEmpty()){
Node first = deque.poll();
nodes.add(first);
if(first.left != null){
deque.offer(first.left);
}
if(first.right != null){
deque.offer(first.right);
}
}
return nodes;
}
public static void main(String[] args) {
SegmentTree<Integer> tree = new SegmentTree<Integer>();
Integer[] data = {1,3,8,5,6};
int n = data.length;
tree.construct(data);
System.out.println(tree.breadthFirstSearch());
System.out.println("区间1-3的最小值为:"+tree.getMin(n, 1, 3));
System.out.println( "区间1-4的最大值为"+tree.getMax(n, 1, 4));
tree.updateDelta(3, 1, 3);
//data = {1, 3+3, 8+3, 5+3, 6};
System.out.println(tree.breadthFirstSearch());
System.out.println( "修改delta后区间1-3的最小值为:"+tree.getMin(n, 1, 3));
System.out.println( "修改delta后区间1-4的最大值为"+tree.getMax(n, 1, 4));
tree.updateSame(2, 1, 3);
System.out.println(tree.breadthFirstSearch());
System.out.println( "修改same后区间1-3的最小值为:"+tree.getMin(n, 1, 3));
System.out.println( "修改same后区间1-4的最大值为"+tree.getMax(n, 1, 4));
}
}
测试结果如下:
[[0~4:min=1,max=8,sum=23,delta=null], [0~2:min=1,max=8,sum=12,delta=null], [3~4:min=5,max=6,sum=11,delta=null], [0~1:min=1,max=3,sum=4,delta=null], [2~2:min=8,max=8,sum=8,delta=null], [3~3:min=5,max=5,sum=5,delta=null], [4~4:min=6,max=6,sum=6,delta=null], [0~0:min=1,max=1,sum=1,delta=null], [1~1:min=3,max=3,sum=3,delta=null]]
区间1-3的最小值为:3
区间1-4的最大值为8
[[0~4:min=1,max=11,sum=32,delta=null], [0~2:min=1,max=11,sum=18,delta=null], [3~4:min=6,max=8,sum=14,delta=null], [0~1:min=1,max=6,sum=7,delta=null], [2~2:min=11,max=11,sum=11,delta=3], [3~3:min=8,max=8,sum=8,delta=3], [4~4:min=6,max=6,sum=6,delta=null], [0~0:min=1,max=1,sum=1,delta=null], [1~1:min=6,max=6,sum=6,delta=3]]
修改delta后区间1-3的最小值为:6
修改delta后区间1-4的最大值为11
[[0~4:min=1,max=6,sum=13,delta=null], [0~2:min=1,max=2,sum=5,delta=null], [3~4:min=2,max=6,sum=8,delta=null], [0~1:min=1,max=2,sum=3,delta=null], [2~2:min=2,max=2,sum=2,delta=-9], [3~3:min=2,max=2,sum=2,delta=-6], [4~4:min=6,max=6,sum=6,delta=null], [0~0:min=1,max=1,sum=1,delta=null], [1~1:min=2,max=2,sum=2,delta=-4]]
修改same后区间1-3的最小值为:2
修改same后区间1-4的最大值为6