Segment Tree 的基本操作 Segment Tree Build, Segment Tree Query, Segment Tree Modify 必须熟练掌握;
线段树长什么样子,就是上面的样子,注意到数组A并不要求是sort的,range 是index的范围,每个leaf节点就是A中每个element.
Interval Sum 思路:这个题目其实是为后面的 follow up 做准备的(数组不变时,用 prefix sum 就能 O(1) 回答每个 query);如果数组是动态的,那么就要用到 segment tree 或者树状数组,每次 query 为 O(logN)。这里用 segment tree 来解答;Time: 建树 O(N),每次 query O(logN),总共 O(N + m logN),m 是 query 的次数,N 是数组长度(线段树是按 index 建的,与数组里的最大值无关);Space: O(N)
/**
* Definition of Interval:
* public class Interval {
* int start, end;
* Interval(int start, int end) {
* this.start = start;
* this.end = end;
* }
* }
*/
public class Solution {
/**
* @param A: An integer list
* @param queries: An query list
* @return: The result list
*/
public List<Long> intervalSum(int[] A, List<Interval> queries) {
List<Long> list = new ArrayList<>();
SegmentTree segmentTree = new SegmentTree(A);
for(Interval interval: queries) {
list.add(segmentTree.query(interval.start, interval.end));
}
return list;
}
private class SegmentTreeNode {
public int start, end;
public long sum;
public SegmentTreeNode left, right;
public SegmentTreeNode (int start, int end) {
this.start = start;
this.end = end;
this.sum = 0;
this.left = null;
this.right = null;
}
}
private class SegmentTree {
public SegmentTreeNode root;
public int size;
public SegmentTree(int[] A) {
this.size = A.length;
this.root = buildTree(A, 0, size - 1);
}
private SegmentTreeNode buildTree(int[] A, int start, int end) {
if(start > end) {
return null;
}
SegmentTreeNode node = new SegmentTreeNode(start, end);
if(start == end) {
node.sum = A[start];
return node;
}
int mid = start + (end - start) / 2;
node.left = buildTree(A, start, mid);
node.right = buildTree(A, mid + 1, end);
node.sum = node.left.sum + node.right.sum;
return node;
}
private long querySum(SegmentTreeNode root, int start, int end) {
if(root.start == start && root.end == end) {
return root.sum;
}
int mid = root.start + (root.end - root.start) / 2;
long leftsum = 0, rightsum = 0;
if(start <= mid) {
leftsum = querySum(root.left, start, Math.min(mid, end));
}
if(end >= mid + 1) {
rightsum = querySum(root.right, Math.max(start, mid + 1), end);
}
return leftsum + rightsum;
}
public long query(int start, int end) {
return querySum(root, start, end);
}
}
}
Interval Sum II 思路:如果数组会被 modify,那么 prefix sum 就没什么用了,因为每次 modify 之后都要花 O(N) 重新更新 prefix sum;所以这题正确的解法还是 segment tree,query 和 modify 都是 O(logN)。
public class Solution {
    /** Segment tree over the indices of the array handed to the constructor. */
    private SegmentTree tree;

    /**
     * Builds the range-sum structure over {@code A}.
     *
     * @param A the initial values
     */
    public Solution(int[] A) {
        tree = new SegmentTree(A);
    }

    /**
     * Returns A[start] + ... + A[end] (inclusive), reflecting every modify so far.
     *
     * @param start first index of the range
     * @param end last index of the range
     * @return the range sum
     */
    public long query(int start, int end) {
        return tree.querySum(start, end);
    }

    /**
     * Overwrites A[index] with {@code value}.
     *
     * @param index position to change
     * @param value new value stored at that position
     */
    public void modify(int index, int value) {
        tree.modify(index, value);
    }

    /** Node covering the inclusive index range [start, end]; {@code sum} totals that range. */
    private static class SegmentTreeNode {
        public int start, end;
        public long sum;
        public SegmentTreeNode left, right;

        public SegmentTreeNode(int start, int end) {
            this.start = start;
            this.end = end;
        }
    }

    /** Segment tree supporting point assignment and inclusive range sums. */
    private static class SegmentTree {
        private final SegmentTreeNode root;
        private final int size;

        public SegmentTree(int[] A) {
            this.size = A.length;
            this.root = build(A, 0, this.size - 1);
        }

        public long querySum(int start, int end) {
            return sumRange(root, start, end);
        }

        public void modify(int index, int value) {
            assign(root, index, value);
        }

        /** Builds the subtree over values[lo..hi]; an empty range yields null. */
        private static SegmentTreeNode build(int[] values, int lo, int hi) {
            if (lo > hi) {
                return null;
            }
            SegmentTreeNode node = new SegmentTreeNode(lo, hi);
            if (lo == hi) {
                node.sum = values[lo];
                return node;
            }
            int mid = lo + (hi - lo) / 2;
            node.left = build(values, lo, mid);
            node.right = build(values, mid + 1, hi);
            node.sum = node.left.sum + node.right.sum;
            return node;
        }

        /** Sums [lo, hi] by splitting the interval at each node's midpoint. */
        private static long sumRange(SegmentTreeNode node, int lo, int hi) {
            if (node.start == lo && node.end == hi) {
                return node.sum;
            }
            int mid = node.start + (node.end - node.start) / 2;
            long total = 0;
            if (lo <= mid) {
                total += sumRange(node.left, lo, Math.min(mid, hi));
            }
            if (hi > mid) {
                total += sumRange(node.right, Math.max(lo, mid + 1), hi);
            }
            return total;
        }

        /** Descends to the leaf for {@code index}, stores the value, then refreshes each ancestor's sum. */
        private static void assign(SegmentTreeNode node, int index, int value) {
            boolean isTargetLeaf = node.start == node.end && node.end == index;
            if (isTargetLeaf) {
                node.sum = value;
                return;
            }
            int mid = node.start + (node.end - node.start) / 2;
            assign(index <= mid ? node.left : node.right, index, value);
            node.sum = node.left.sum + node.right.sum;
        }
    }
}
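Count of Smaller Number 思路:因为 A 已经给定了,所以可以先统计每个值出现的次数得到数组 B(B[v] = v 在 A 中出现的次数,值域 [0, 10000]),再用 B 来 build tree;对每个 query q,答案就是 B[0..q-1] 的区间和。注意 q == 0 时区间 [0, -1] 为空,要直接返回 0。建立树是 O(N + K),K = 10001 为值域大小;查询总共 O(m logK)。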
public class Solution {
    /** Number of slots in the value domain [0, 10000] this solution counts over. */
    private static final int VALUE_RANGE = 10001;

    /**
     * For each query q, counts how many elements of {@code A} are strictly smaller than q.
     *
     * <p>A histogram B is built first (B[v] = occurrences of v in A), then a segment tree over B
     * answers "how many values lie in [0, q - 1]" as a range sum. Build is O(N + K) and each
     * query is O(log K), with K = {@value #VALUE_RANGE}.
     *
     * @param A the array; every value must lie in [0, 10000], otherwise the histogram update
     *     throws {@link ArrayIndexOutOfBoundsException}
     * @param queries the thresholds to count against; any int is accepted
     * @return one count per query, in query order
     */
    public List<Integer> countOfSmallerNumber(int[] A, int[] queries) {
        List<Integer> list = new ArrayList<Integer>();
        int[] B = new int[VALUE_RANGE];
        for (int i : A) {
            B[i]++;
        }
        SegmentTree tree = new SegmentTree(B);
        int maxValue = VALUE_RANGE - 1;
        for (int q : queries) {
            if (q <= 0) {
                // No value is below 0, and [0, q - 1] would be an empty/invalid range that
                // walks off a leaf into a null child.
                list.add(0);
            } else {
                // Clamp so a threshold above the domain counts every element instead of
                // asking for indices the tree does not have.
                list.add(tree.querySum(0, Math.min(q - 1, maxValue)));
            }
        }
        return list;
    }

    /** Node over the inclusive value range [start, end]; {@code sum} is how many elements fall there. */
    private static class SegmentTreeNode {
        public int start, end;
        public SegmentTreeNode left, right;
        public int sum;

        public SegmentTreeNode(int start, int end) {
            this.start = start;
            this.end = end;
        }
    }

    /** Build-once segment tree of range sums over an int array. */
    private static class SegmentTree {
        private SegmentTreeNode root;
        private int size;

        public SegmentTree(int[] A) {
            this.size = A.length;
            this.root = buildTree(A, 0, size - 1);
        }

        /** Builds the subtree over A[start..end]; an empty range yields null. */
        private SegmentTreeNode buildTree(int[] A, int start, int end) {
            if (start > end) {
                return null;
            }
            SegmentTreeNode root = new SegmentTreeNode(start, end);
            if (start == end) {
                root.sum = A[start];
                return root;
            }
            int mid = start + (end - start) / 2;
            root.left = buildTree(A, start, mid);
            root.right = buildTree(A, mid + 1, end);
            root.sum = root.left.sum + root.right.sum;
            return root;
        }

        /** Sums [start, end], which must lie inside {@code root}'s range with start <= end. */
        private int querySum(SegmentTreeNode root, int start, int end) {
            if (root.start == start && root.end == end) {
                return root.sum;
            }
            int mid = root.start + (root.end - root.start) / 2;
            int leftsum = 0, rightsum = 0;
            if (start <= mid) {
                leftsum = querySum(root.left, start, Math.min(mid, end));
            }
            if (end >= mid + 1) {
                rightsum = querySum(root.right, Math.max(start, mid + 1), end);
            }
            return leftsum + rightsum;
        }

        public int querySum(int start, int end) {
            return querySum(root, start, end);
        }
    }
}
这里可以稍微优化一下,省去一个 size 为 10001 的计数数组:把 modify 稍微改一下变成 add,先建立一棵空的 segment tree,然后把 A 中每个数依次 add 进去(对应叶子 +1);建树是 O(K),插入是 O(N logK),查询是 O(m logK),K = 10001 为值域大小。
public class Solution {
    /** Number of slots in the value domain [0, 10000] this solution counts over. */
    private static final int VALUE_RANGE = 10001;

    /**
     * For each query q, counts how many elements of {@code A} are strictly smaller than q.
     *
     * <p>Starts from an empty count tree over the value domain and adds 1 at each element's
     * value, so no separate histogram array is needed. Each query is then the range sum
     * over [0, q - 1].
     *
     * @param A the array; every value must lie in [0, 10000], otherwise {@code add} walks
     *     off the tree and throws {@link NullPointerException}
     * @param queries the thresholds to count against; any int is accepted
     * @return one count per query, in query order
     */
    public List<Integer> countOfSmallerNumber(int[] A, int[] queries) {
        List<Integer> list = new ArrayList<Integer>();
        SegmentTree tree = new SegmentTree(VALUE_RANGE);
        for (int i : A) {
            tree.add(i, 1);
        }
        int maxValue = VALUE_RANGE - 1;
        for (int q : queries) {
            if (q <= 0) {
                // Nothing is below 0; [0, q - 1] is not a valid range and would hit a null child.
                list.add(0);
            } else {
                // Clamp so thresholds above the domain count every element rather than crash.
                list.add(tree.querySum(0, Math.min(q - 1, maxValue)));
            }
        }
        return list;
    }

    /** Node over the inclusive value range [start, end]; {@code sum} counts elements added there. */
    private static class SegmentTreeNode {
        public int start, end;
        public SegmentTreeNode left, right;
        public int sum;

        public SegmentTreeNode(int start, int end) {
            this.start = start;
            this.end = end;
        }
    }

    /** Count tree over [0, size - 1] supporting point increments and inclusive range sums. */
    private static class SegmentTree {
        private SegmentTreeNode root;
        private int size;

        public SegmentTree(int size) {
            this.size = size;
            this.root = buildTree(0, size - 1);
        }

        /** Builds an all-zero subtree over [start, end]; an empty range yields null. */
        private SegmentTreeNode buildTree(int start, int end) {
            if (start > end) {
                return null;
            }
            SegmentTreeNode root = new SegmentTreeNode(start, end);
            if (start == end) {
                return root;
            }
            int mid = start + (end - start) / 2;
            root.left = buildTree(start, mid);
            root.right = buildTree(mid + 1, end);
            return root;
        }

        /** Sums [start, end], which must lie inside {@code root}'s range with start <= end. */
        private int querySum(SegmentTreeNode root, int start, int end) {
            if (root.start == start && root.end == end) {
                return root.sum;
            }
            int mid = root.start + (root.end - root.start) / 2;
            int leftsum = 0, rightsum = 0;
            if (start <= mid) {
                leftsum = querySum(root.left, start, Math.min(mid, end));
            }
            if (end >= mid + 1) {
                rightsum = querySum(root.right, Math.max(start, mid + 1), end);
            }
            return leftsum + rightsum;
        }

        /** Adds {@code value} to the leaf for {@code index} and refreshes every ancestor's sum. */
        private void add(SegmentTreeNode root, int index, int value) {
            if (root.start == root.end && root.end == index) {
                root.sum += value;
                return;
            }
            int mid = root.start + (root.end - root.start) / 2;
            if (index <= mid) {
                add(root.left, index, value);
            } else {
                add(root.right, index, value);
            }
            root.sum = root.left.sum + root.right.sum;
        }

        public int querySum(int start, int end) {
            return querySum(root, start, end);
        }

        public void add(int index, int value) {
            add(root, index, value);
        }
    }
}
Count of Smaller Number before itself 思路:用 segment tree 记录当前点之前出现过的每个值的 count。首先建立一棵空的线段树(覆盖值域 [0, 10000]);从左往右扫,每遇到一个数 A[i],先 querySum(0, A[i]-1) 得到它之前比它小的数的个数,再把值 A[i] 对应叶子节点的 count ++。注意一定要特判 A[i] == 0:数组的值都在 [0, 10000] 内,没有比 0 小的数,而且 [0, -1] 是空区间,所以 list 直接加 0;但是 0 这个 node 的 count 还是要 ++,因为它会计入后面比 0 大的数的答案。
public class Solution {
    /**
     * For every position i, counts how many of A[0..i-1] are strictly smaller than A[i].
     *
     * <p>Scans left to right over a count tree indexed by value (leaves 0..10000): first asks
     * how many recorded values fall in [0, A[i] - 1], then records A[i] itself.
     *
     * @param A: an integer array
     * @return one count per element of A, in the same order
     */
    public List<Integer> countOfSmallerNumberII(int[] A) {
        SegmentTree seen = new SegmentTree(10001);
        List<Integer> result = new ArrayList<>();
        for (int i = 0; i < A.length; i++) {
            int value = A[i];
            // Nothing lies below 0, and [0, -1] is not a valid tree range, so 0 is answered directly;
            // the 0 is still recorded because it counts toward later, larger values.
            result.add(value == 0 ? 0 : seen.querySum(0, value - 1));
            seen.add(value, 1);
        }
        return result;
    }

    /** Node over the inclusive value range [start, end]; {@code sum} counts values recorded there. */
    public class SegmentTreeNode {
        public int start, end;
        public int sum;
        public SegmentTreeNode left, right;

        public SegmentTreeNode (int start, int end) {
            this.start = start;
            this.end = end;
        }
    }

    /** Count tree over [0, size - 1] with point increments and inclusive range sums. */
    public class SegmentTree {
        public SegmentTreeNode root;
        public int size;

        public SegmentTree (int size) {
            this.size = size;
            root = buildTree(0, size - 1);
        }

        /** Builds an all-zero subtree over [start, end]; returns null when the range is empty. */
        public SegmentTreeNode buildTree(int start, int end) {
            if (end < start) {
                return null;
            }
            SegmentTreeNode node = new SegmentTreeNode(start, end);
            if (start != end) {
                int mid = start + (end - start) / 2;
                node.left = buildTree(start, mid);
                node.right = buildTree(mid + 1, end);
            }
            return node;
        }

        public int querySum(int start, int end) {
            return querySumHelper(root, start, end);
        }

        /** Sums [start, end] inside {@code node}'s range, splitting at the node's midpoint. */
        public int querySumHelper(SegmentTreeNode node, int start, int end) {
            boolean exactCover = node.start == start && node.end == end;
            if (exactCover) {
                return node.sum;
            }
            int mid = node.start + (node.end - node.start) / 2;
            int total = 0;
            if (start <= mid) {
                total += querySumHelper(node.left, start, Math.min(mid, end));
            }
            if (end > mid) {
                total += querySumHelper(node.right, Math.max(mid + 1, start), end);
            }
            return total;
        }

        public void add(int index, int value) {
            addHelper(root, index, value);
        }

        /** Adds {@code value} at the leaf for {@code index}, then recomputes this node's sum. */
        public void addHelper(SegmentTreeNode node, int index, int value) {
            if (node.start == node.end && node.end == index) {
                node.sum += value;
                return;
            }
            int mid = node.start + (node.end - node.start) / 2;
            SegmentTreeNode next = index <= mid ? node.left : node.right;
            addHelper(next, index, value);
            node.sum = node.left.sum + node.right.sum;
        }
    }
}
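Count of Smaller Numbers After Self 思路:O(NlogN) 的算法沿用 Count of Smaller Number 和 Count of Smaller Number before itself,还是用线段树求解。但是这个题目有负数的情况,解决方法很简单:求得 max 和 min 之后,把整个值域往右 shift |min| 位就可以了(如果 min > 0,就没必要 shift);segment tree 的 class 一点都不需要改变,只需要在计数时把每个值都加上 Math.abs(min)。更通用的做法是离散化(coordinate compression):把每个值换成它在所有不同值中的排名,这样线段树大小是 O(N),也不会因为值域太大而溢出或占用过多内存。把 before 的代码稍微改改就可以用了:注意题目要求的是右边的 count,所以要从右往左扫;结果要按原顺序返回,如果用 ArrayList 的 add(0, count),每次都要移动所有元素,整体会退化成 O(N²),更好的办法是先写进数组,最后按顺序放进 list。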
class Solution {
    /**
     * For each index i, counts the elements to the right of i that are strictly smaller than nums[i].
     *
     * <p>Values are coordinate-compressed: each one is replaced by its rank among the distinct
     * values, so the segment tree has one leaf per distinct value. This generalizes shifting by
     * |min|. It works for the full int range without overflow, and the tree size is O(N)
     * regardless of how far apart the values are. Scanning right to left, the answer for nums[i]
     * is the number of already-inserted ranks strictly below its own rank.
     *
     * <p>Time O(N log N), space O(N).
     *
     * @param nums input values; {@code null} or empty yields an empty list
     * @return a mutable list whose i-th entry is the count for nums[i]
     */
    public List<Integer> countSmaller(int[] nums) {
        List<Integer> list = new ArrayList<Integer>();
        if (nums == null || nums.length == 0) {
            return list;
        }
        int[] distinct = sortedDistinct(nums);
        SegmentTree tree = new SegmentTree(distinct.length);
        int[] counts = new int[nums.length];
        for (int i = nums.length - 1; i >= 0; i--) {
            // Every nums[i] is present in distinct, so binarySearch returns its exact rank.
            int rank = java.util.Arrays.binarySearch(distinct, nums[i]);
            counts[i] = rank == 0 ? 0 : tree.querySum(0, rank - 1);
            tree.add(rank, 1);
        }
        // Appending in index order avoids ArrayList.add(0, x), which shifts every element
        // on each call and would make the whole method O(N^2).
        for (int c : counts) {
            list.add(c);
        }
        return list;
    }

    /** Returns the distinct values of {@code nums} in ascending order, without modifying {@code nums}. */
    private static int[] sortedDistinct(int[] nums) {
        int[] sorted = nums.clone();
        java.util.Arrays.sort(sorted);
        int unique = 0;
        for (int v : sorted) {
            // Writes land at index <= the current read position, so unread entries are never clobbered.
            if (unique == 0 || sorted[unique - 1] != v) {
                sorted[unique++] = v;
            }
        }
        return java.util.Arrays.copyOf(sorted, unique);
    }

    /** Node over the inclusive rank range [start, end]; {@code sum} counts ranks inserted there. */
    private static class SegmentTreeNode {
        public int start, end;
        public int sum;
        public SegmentTreeNode left, right;

        public SegmentTreeNode(int start, int end) {
            this.start = start;
            this.end = end;
        }
    }

    /** Count tree over [0, size - 1] supporting point increments and inclusive range sums. */
    private static class SegmentTree {
        public SegmentTreeNode root;
        public int size;

        public SegmentTree(int size) {
            this.size = size;
            this.root = buildTree(0, size - 1);
        }

        /** Builds an all-zero subtree over [start, end]; an empty range yields null. */
        private SegmentTreeNode buildTree(int start, int end) {
            if (start > end) {
                return null;
            }
            SegmentTreeNode root = new SegmentTreeNode(start, end);
            if (start == end) {
                return root;
            }
            int mid = start + (end - start) / 2;
            root.left = buildTree(start, mid);
            root.right = buildTree(mid + 1, end);
            return root;
        }

        /** Sums [start, end], which must lie inside {@code root}'s range with start <= end. */
        private int querySum(SegmentTreeNode root, int start, int end) {
            if (root.start == start && root.end == end) {
                return root.sum;
            }
            int mid = root.start + (root.end - root.start) / 2;
            int leftsum = 0, rightsum = 0;
            if (start <= mid) {
                leftsum = querySum(root.left, start, Math.min(mid, end));
            }
            if (end >= mid + 1) {
                rightsum = querySum(root.right, Math.max(start, mid + 1), end);
            }
            return leftsum + rightsum;
        }

        /** Adds {@code value} at the leaf for {@code index} and refreshes every ancestor's sum. */
        private void add(SegmentTreeNode root, int index, int value) {
            if (root.start == root.end && root.end == index) {
                root.sum += value;
                return;
            }
            int mid = root.start + (root.end - root.start) / 2;
            if (index <= mid) {
                add(root.left, index, value);
            } else {
                add(root.right, index, value);
            }
            root.sum = root.left.sum + root.right.sum;
        }

        public int querySum(int start, int end) {
            return querySum(root, start, end);
        }

        public void add(int index, int value) {
            add(root, index, value);
        }
    }
}
Range Sum Query - Mutable 思路:用 segment tree,可以直接用 array build tree,也可以先建立一棵空的 segment tree,然后对每个 index 做 modify(i, nums[i]),再按照模板写出 segment tree 即可;Time: build tree O(N),query 和 modify 都是 O(logN)。
class NumArray {
private SegmentTree tree;
public NumArray(int[] nums) {
tree = new SegmentTree(nums);
}
public void update(int i, int val) {
tree.modify(i, val);
}
public int sumRange(int i, int j) {
return tree.querySum(i, j);
}
private class SegmentTreeNode {
public int start, end;
public int sum;
public SegmentTreeNode left, right;
public SegmentTreeNode (int start, int end) {
this.start = start;
this.end = end;
}
}
private class SegmentTree {
public SegmentTreeNode root;
public int size;
public SegmentTree(int[] A) {
this.size = A.length;
this.root = buildTree(A, 0, size - 1);
}
private SegmentTreeNode buildTree(int[] A, int start, int end) {
if(start > end) {
return null;
}
SegmentTreeNode root = new SegmentTreeNode(start, end);
if(start == end) {
root.sum = A[start];
return root;
}
int mid = start + (end - start) / 2;
root.left = buildTree(A, start, mid);
root.right = buildTree(A, mid + 1, end);
root.sum = root.left.sum + root.right.sum;
return root;
}
private int querySum(SegmentTreeNode root, int start, int end) {
if(root.start == start && root.end == end) {
return root.sum;
}
int mid = root.start + (root.end - root.start) / 2;
int leftsum = 0, rightsum = 0;
if(start <= mid) {
leftsum = querySum(root.left, start, Math.min(mid, end));
}
if(end >= mid + 1) {
rightsum = querySum(root.right, Math.max(start, mid + 1), end);
}
return leftsum + rightsum;
}
private void modify(SegmentTreeNode root, int index, int value) {
if(root.start == root.end && root.end == index) {
root.sum = value;
return;
}
int mid = root.start + (root.end - root.start) / 2;
if(index <= mid) {
modify(root.left, index, value);
} else {
modify(root.right, index, value);
}
root.sum = root.left.sum + root.right.sum;
}
public int querySum(int start, int end) {
return querySum(root, start, end);
}
public void modify(int index, int value) {
modify(root, index, value);
}
}
}
/**
* Your NumArray object will be instantiated and called as such:
* NumArray obj = new NumArray(nums);
* obj.update(i,val);
* int param_2 = obj.sumRange(i,j);
*/