单点更新–求区间最大(小)值
import java.util.Arrays;
/**
 * Fenwick tree (binary indexed tree) over a 1-based array that supports point
 * assignment and range maximum / minimum queries.
 *
 * <p>After {@link #update} returns, {@code bitMax[i]} / {@code bitMin[i]} hold the
 * maximum / minimum of {@code arr[i - lowBit(i) + 1 .. i]} for every node on the
 * updated path. Positions that were never assigned read as {@code 0} (the array
 * default) and take part in those aggregates, so queries should stay inside the
 * range of positions that have been assigned.
 *
 * <p>All state is static. Usable positions are {@code 1 .. maxN - 1}.
 */
public class BIT_MaxMin {
    /** Length of the backing arrays; index 0 is unused. */
    static final int maxN = 50;
    /** Number of elements loaded by {@link #main}. */
    static int N;
    static int[] bitMax = new int[maxN];
    static int[] bitMin = new int[maxN];
    /** Current value stored at each position (1-based). */
    static int[] arr = new int[maxN];

    /** Loads a sample array and prints the maximum and minimum of positions 1..7. */
    public static void main(String[] args) throws Exception {
        int[] in = new int[]{5, 6, 7, 7, 4, 2, 8, 3, 9, 9, 9, 10, 10, 11};
        N = in.length;
        for (int i = 1; i <= N; i++) {
            // update() records the value in arr[] itself.
            update(i, in[i - 1]);
        }
        int[] ans = queryMaxMin(1, 7);
        System.out.println("1~7区间最大值: " + ans[0]); // 8
        System.out.println("1~7区间最小值: " + ans[1]); // 2
    }

    /** Returns the lowest set bit of {@code x} ({@code x & -x}); 0 when {@code x} is 0. */
    static int lowBit(int x) {
        return x & -x;
    }

    /**
     * Assigns {@code val} to position {@code x} and rebuilds every tree node whose
     * range contains {@code x}. Max and min are maintained together; they could be
     * split into two methods.
     *
     * @param x   1-based position, {@code 1 <= x < maxN}
     * @param val new value for the position
     * @throws IllegalArgumentException if {@code x} is outside {@code [1, maxN - 1]}
     */
    static void update(int x, int val) {
        // x == 0 would never advance (lowBit(0) == 0); x >= maxN is outside the arrays.
        if (x < 1 || x >= maxN) {
            throw new IllegalArgumentException(
                    "position out of range [1, " + (maxN - 1) + "]: " + x);
        }
        arr[x] = val;
        // Strictly less than maxN: the arrays have maxN slots, so index maxN is invalid.
        while (x < maxN) {
            int span = lowBit(x);
            int hi = arr[x];
            int lo = arr[x];
            // Node x covers arr[x] plus the child nodes x-1, x-2, x-4, ... (< span).
            for (int i = 1; i < span; i *= 2) {
                hi = Math.max(hi, bitMax[x - i]);
                lo = Math.min(lo, bitMin[x - i]);
            }
            bitMax[x] = hi;
            bitMin[x] = lo;
            x += span;
        }
    }

    /**
     * Returns the maximum and minimum over positions {@code x..y} inclusive.
     * Max and min are computed together; they could be split into two methods.
     *
     * @param x 1-based left end of the range
     * @param y 1-based right end of the range, {@code x <= y < maxN}
     * @return a two-element array {@code {max, min}}
     * @throws IllegalArgumentException if the range is empty or outside {@code [1, maxN - 1]}
     */
    static int[] queryMaxMin(int x, int y) {
        // x == 0 would stall the inner loop at y == 0; y < x would walk below index 0.
        if (x < 1 || y < x || y >= maxN) {
            throw new IllegalArgumentException("invalid range [" + x + ", " + y + "]");
        }
        int max = arr[y];
        int min = arr[y];
        while (y != x) {
            // Consume whole tree nodes while their range stays strictly right of x.
            for (y--; y - lowBit(y) >= x; y -= lowBit(y)) {
                max = Math.max(max, bitMax[y]);
                min = Math.min(min, bitMin[y]);
            }
            // The node at y would reach left of x, so take only the single element.
            max = Math.max(max, arr[y]);
            min = Math.min(min, arr[y]);
        }
        return new int[]{max, min};
    }
}
最长递增子序列的长度(LIS、LNDS)
LIS(Longest Increasing Subsequence):最长单调递增子序列长度
LNDS:最长不递减子序列长度
import java.util.Arrays;
import java.util.Comparator;
public class BIT_LIS_LNDS {
static final int maxN = 50;
static int N;
static int lis = 0, lnds = 0;
static Node[] lisArr = new Node[maxN];
static Node[] lndsArr = new Node[maxN];
static int[] bit = new int[maxN];
static Comparator<Node> lisCmp = new Comparator<Node>() {
@Override
public int compare(Node o1, Node o2) {
//lds 按值升序 索引降序
return o1.val == o2.val ? o2.idx - o1.idx : o1.val - o2.val;
}
};
static Comparator<Node> lndsCmp = new Comparator<Node>() {
@Override
public int compare(Node o1, Node o2) {
//lds 按值升序 索引升序
return o1.val == o2.val ? o1.idx - o2.idx : o1.val - o2.val;
}
};
public static void main(String[] args) {
int[] in = new int[]{5, 6, 7, 7, 4, 2, 8, 3, 9, 9, 9, 10, 10, 11};
//lis(7):5 6 7 8 9 10 11
//lnds(11):5 6 7 7 8 9 9 9 10 10 11
N = in.length;
for (int i = 0; i < N; i++) {
lisArr[i + 1] = new Node(i + 1, in[i]);
lndsArr[i + 1] = new Node(i + 1, in[i]);
}
Arrays.sort(lisArr, 1, N + 1, lisCmp);
for (int i = 1; i <= N; i++) {
int len = query(lisArr[i].idx-1) + 1;
update(lisArr[i].idx, len);
lis = Math.max(lis, len);
}
System.out.println("lis: " + lis);
Arrays.fill(bit, 0);
Arrays.sort(lndsArr, 1, N + 1, lndsCmp);
for (int i = 1; i <= N; i++) {
int len = query(lndsArr[i].idx-1) + 1;
update(lndsArr[i].idx, len);
lnds = Math.max(lnds, len);
}
System.out.println("lnds: " + lnds);
}
static int lowBit(int x) {
// x & x取反+1
return x & -x;
}
static void update(int x, int val) {
while (x < maxN) {
bit[x] = Math.max(bit[x], val);
x += lowBit(x);
}
}
static int query(int x) {
int max = 0;
while (x > 0) {
max = Math.max(bit[x], max);
x -= lowBit(x);
}
return max;
}
static class Node {
int idx;
int val;
Node(int idx, int val) {
this.idx = idx;
this.val = val;
}
}
}
最长递减子序列的长度(LDS、LNIS)
LDS(Longest Decreasing Subsequence):最长单调递减子序列长度
LNIS:最长不递增子序列长度
import java.util.Arrays;
import java.util.Comparator;
public class BIT_LDS_LNIS {
static final int maxN = 50;
static int N;
static int lds = 0, lnis = 0;
static Node[] ldsArr = new Node[maxN];
static Node[] lnisArr = new Node[maxN];
static int[] bit = new int[maxN];
static Comparator<Node> ldsCmp = new Comparator<Node>() {
@Override
public int compare(Node o1, Node o2) {
//lds 按值升序 索引升序
return o1.val == o2.val ? o1.idx - o2.idx : o1.val - o2.val;
}
};
static Comparator<Node> lnisCmp = new Comparator<Node>() {
@Override
public int compare(Node o1, Node o2) {
//lnis 按值升序 索引降序
return o1.val == o2.val ? o2.idx - o1.idx : o1.val - o2.val;
}
};
public static void main(String[] args) {
int[] in = new int[]{11, 10, 10, 9, 9, 9, 3, 8, 2, 4, 7, 7, 6, 5};
//lds(7):11 10 9 8 7 6 5
//lnis(11):11 10 10 9 9 9 8 7 7 6 5
N = in.length;
for (int i = 0; i < N; i++) {
ldsArr[i + 1] = new Node(i + 1, in[i]);
lnisArr[i + 1] = new Node(i + 1, in[i]);
}
Arrays.sort(ldsArr, 1, N + 1, ldsCmp);
for (int i = N; i >= 1; i--) {
int len = query(ldsArr[i].idx-1) + 1;
update(ldsArr[i].idx, len);
lds = Math.max(lds, len);
}
System.out.println("lds: " + lds);
Arrays.fill(bit, 0);
Arrays.sort(lnisArr, 1, N + 1, lnisCmp);
for (int i = N; i >= 1; i--) {
int len = query(lnisArr[i].idx-1) + 1;
update(lnisArr[i].idx, len);
lnis = Math.max(lnis, len);
}
System.out.println("lnis: " + lnis);
}
static int lowBit(int x) {
// x & x取反+1
return x & -x;
}
static void update(int x, int val) {
while (x < maxN) {
bit[x] = Math.max(bit[x], val);
x += lowBit(x);
}
}
static int query(int x) {
int max = 0;
while (x > 0) {
max = Math.max(bit[x], max);
x -= lowBit(x);
}
return max;
}
static class Node {
int idx;
int val;
Node(int idx, int val) {
this.idx = idx;
this.val = val;
}
}
}
单点更新–求区间和
/**
 * Fenwick tree (binary indexed tree) with point updates and prefix-sum queries
 * over a 1-based sequence. All state is static; usable positions are
 * {@code 1 .. maxN - 1}.
 */
public class BIT_Sum {
    /** Length of the backing array; index 0 is unused. */
    static final int maxN = 50;
    static int N;
    static int[] bit = new int[maxN];

    /** Builds the tree from a sample array and prints a range sum before and after a point change. */
    public static void main(String[] args) throws Exception {
        int[] in = new int[]{5, 6, 7, 7, 4, 2, 8, 3, 9, 9, 9, 10, 10, 11};
        N = in.length;
        for (int i = 0; i < N; i++) {
            update(i + 1, in[i]); // initial load: add each value at its position
        }
        // Sum of items x..y:            query(y) - query(x - 1)
        // Change item x from old to new: update(x, new - old)  (update ADDS a delta)
        System.out.println(query(5) - query(2 - 1)); // items 2..5: 6+7+7+4=24
        update(3, 10 - 7); // item 3 changes from 7 to 10, i.e. delta +3
        System.out.println(query(5) - query(2 - 1)); // items 2..5: 6+10+7+4=27
    }

    /** Returns the lowest set bit of {@code x} ({@code x & -x}); 0 when {@code x} is 0. */
    static int lowBit(int x) {
        return x & (-x);
    }

    /**
     * Adds {@code val} to the item at position {@code x}. Positions at or beyond
     * {@code maxN} fall outside the tree and are ignored.
     *
     * @param x   1-based position
     * @param val delta to add (not the new absolute value)
     * @throws IllegalArgumentException if {@code x < 1}
     */
    static void update(int x, int val) {
        // lowBit(0) == 0, so x == 0 would never advance; negatives would index out of bounds.
        if (x < 1) {
            throw new IllegalArgumentException("position must be >= 1: " + x);
        }
        for (; x < maxN; x += lowBit(x)) {
            bit[x] += val;
        }
    }

    /**
     * Returns the sum of items {@code 1..x}; 0 when {@code x <= 0}.
     *
     * @param x 1-based right end of the prefix, {@code x < maxN}
     */
    static int query(int x) {
        int sum = 0;
        for (; x > 0; x -= lowBit(x)) {
            sum += bit[x];
        }
        return sum;
    }
}
区间更新–单点查询
区间更新,暴力方法:遍历区间,对区间中每个值更新,再求和。
优化方法:使用差分数组
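原数组为 ai[i],差分数组 di[i] = ai[i] − ai[i−1](约定 ai[0] = 0),则 ai[i] = di[1] + di[2] + … + di[i],即 $a_i = \sum_{j=1}^{i} d_j$。因此对区间 [x, y] 整体加 v 只需 di[x] += v、di[y+1] −= v。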
/**
 * Fenwick tree over the difference array {@code di[i] = ai[i] - ai[i - 1]},
 * giving range-add updates and single-item reads.
 *
 * <p>Because the tree stores differences, the prefix sum {@code query(x)} equals
 * the current value of item {@code x}. All state is static; usable positions are
 * {@code 1 .. maxN - 1}.
 */
public class BIT_RangeUpdate_SingleQuery {
    /** Length of the backing array; index 0 is unused. */
    static final int maxN = 50;
    static int N;
    static int[] bit = new int[maxN];

    /** Builds the tree from a sample array, adds 2 to items 2..7 and prints item 5. */
    public static void main(String[] args) throws Exception {
        // ai[0] is a padding 0 so that di[1] = ai[1]; real items are ai[1..14].
        int[] ai = new int[]{0, 5, 6, 7, 7, 4, 2, 8, 3, 9, 9, 9, 10, 10, 11};
        N = ai.length;
        // Difference array di[i] = ai[i] - ai[i-1]
        // 0 5 1 1 0 -3 -2 6 -5 6 0 0 1 0 1
        int[] di = new int[N];
        for (int i = 1; i < N; i++) {
            di[i] = ai[i] - ai[i - 1];
            update(i, di[i]);
        }
        System.out.print("差分数组: ");
        for (int i = 0; i < N; i++) {
            System.out.print(di[i] + " ");
        }
        System.out.println();
        // Range update, add 2 to items x..y: rangeUpdate(x, y, 2)
        // Single-item read of item x:        query(x)
        rangeUpdate(2, 7, 2);
        System.out.println(query(5)); // item 5 (ai[5]): 4+2=6
    }

    /** Returns the lowest set bit of {@code x} ({@code x & -x}); 0 when {@code x} is 0. */
    static int lowBit(int x) {
        return x & (-x);
    }

    /**
     * Adds {@code val} to every item in {@code x..y} inclusive by adjusting the two
     * boundary differences.
     *
     * @throws IllegalArgumentException if {@code x < 1} or {@code y < x}
     */
    static void rangeUpdate(int x, int y, int val) {
        // Validate before touching the tree so a bad call leaves no partial update.
        if (x < 1 || y < x) {
            throw new IllegalArgumentException("invalid range [" + x + ", " + y + "]");
        }
        update(x, val);
        // The closing difference only matters if y + 1 is inside the tree; comparing
        // y against maxN - 1 also avoids overflowing y + 1.
        if (y < maxN - 1) {
            update(y + 1, -val);
        }
    }

    /**
     * Adds {@code val} to the difference at position {@code x}. Positions at or
     * beyond {@code maxN} fall outside the tree and are ignored.
     *
     * @throws IllegalArgumentException if {@code x < 1}
     */
    static void update(int x, int val) {
        // lowBit(0) == 0, so x == 0 would never advance; negatives would index out of bounds.
        if (x < 1) {
            throw new IllegalArgumentException("position must be >= 1: " + x);
        }
        for (; x < maxN; x += lowBit(x)) {
            bit[x] += val;
        }
    }

    /**
     * Single-item read: returns the current value of item {@code x}, computed as the
     * prefix sum of differences {@code 1..x}; 0 when {@code x <= 0}.
     */
    static int query(int x) {
        int sum = 0;
        for (; x > 0; x -= lowBit(x)) {
            sum += bit[x];
        }
        return sum;
    }
}
区间更新–区间求和
/**
 * Range-add updates and range-sum queries using two Fenwick trees over the
 * difference array {@code di[i] = ai[i] - ai[i - 1]}.
 *
 * <p>With {@code ai[i] = di[1] + ... + di[i]}, the prefix sum of the items is
 * <pre>
 *   ai[1] + ... + ai[x] = (x + 1) * (di[1] + ... + di[x]) - (1*di[1] + ... + x*di[x])
 * </pre>
 * so {@code sum1} accumulates {@code di[i]} and {@code sum2} accumulates
 * {@code i * di[i]}. All state is static; usable positions are
 * {@code 1 .. maxN - 1}.
 */
public class BIT_RangeUpdate_RangeQuery {
    /** Length of the backing arrays; index 0 is unused. */
    static final int maxN = 50;
    static int N;
    static int[] sum1 = new int[maxN]; // Fenwick tree over di[i]
    static int[] sum2 = new int[maxN]; // Fenwick tree over di[i] * i

    /** Builds the trees from a sample array, adds 2 to items 2..7 and prints the sum of items 2..5. */
    public static void main(String[] args) throws Exception {
        // ai[0] is a padding 0 so that di[1] = ai[1]; real items are ai[1..14].
        int[] ai = new int[]{0, 5, 6, 7, 7, 4, 2, 8, 3, 9, 9, 9, 10, 10, 11};
        N = ai.length;
        // di[i] = ai[i] - ai[i-1]
        // 0 5 1 1 0 -3 -2 6 -5 6 0 0 1 0 1
        int[] di = new int[N];
        for (int i = 1; i < N; i++) {
            di[i] = ai[i] - ai[i - 1];
            update(i, di[i]);
        }
        System.out.print("差分数组: ");
        for (int i = 0; i < N; i++) {
            System.out.print(di[i] + " ");
        }
        System.out.println();
        // Range update, add 2 to items x..y: rangeUpdate(x, y, 2)
        // Range sum of items x..y:           query(y) - query(x - 1)
        rangeUpdate(2, 7, 2);
        // Items 2..5 after the update: (6+2)+(7+2)+(7+2)+(4+2) = 8+9+9+6 = 32
        System.out.println(query(5) - query(2 - 1));
    }

    /** Returns the lowest set bit of {@code x} ({@code x & -x}); 0 when {@code x} is 0. */
    static int lowBit(int x) {
        return x & (-x);
    }

    /**
     * Adds {@code val} to every item in {@code x..y} inclusive by adjusting the two
     * boundary differences.
     *
     * @throws IllegalArgumentException if {@code x < 1} or {@code y < x}
     */
    static void rangeUpdate(int x, int y, int val) {
        // Validate before touching the trees so a bad call leaves no partial update.
        if (x < 1 || y < x) {
            throw new IllegalArgumentException("invalid range [" + x + ", " + y + "]");
        }
        update(x, val);
        // The closing difference only matters if y + 1 is inside the trees; comparing
        // y against maxN - 1 also avoids overflowing y + 1.
        if (y < maxN - 1) {
            update(y + 1, -val);
        }
    }

    /**
     * Adds {@code val} to the difference at position {@code x}, updating both trees.
     * Positions at or beyond {@code maxN} fall outside the trees and are ignored.
     *
     * @throws IllegalArgumentException if {@code x < 1}
     */
    static void update(int x, int val) {
        // lowBit(0) == 0, so x == 0 would never advance; negatives would index out of bounds.
        if (x < 1) {
            throw new IllegalArgumentException("position must be >= 1: " + x);
        }
        // The weight is the ORIGINAL position, not the tree node being visited.
        int weighted = x * val;
        for (; x < maxN; x += lowBit(x)) {
            sum1[x] += val;
            sum2[x] += weighted;
        }
    }

    /**
     * Returns the sum of items {@code 1..x}; 0 when {@code x <= 0}.
     *
     * @param x 1-based right end of the prefix, {@code x < maxN}
     */
    static int query(int x) {
        int init = x;
        int plain = 0;    // di[1] + ... + di[x]
        int weighted = 0; // 1*di[1] + ... + x*di[x]
        for (; x > 0; x -= lowBit(x)) {
            plain += sum1[x];
            weighted += sum2[x];
        }
        return (init + 1) * plain - weighted;
    }
}
这个博客写得很好
[https://blog.csdn.net/bestsort/article/details/80796531]