Java树状数组【Binary Indexed Tree】应用

单点更新–求区间最大(小)值

import java.util.Arrays;

public class BIT_MaxMin {
    static final int maxN = 50;
    static int N;
    static int[] bitMax = new int[maxN];
    static int[] bitMin = new int[maxN];
    static int[] arr = new int[maxN];

    public static void main(String[] args) throws Exception {
        int[] in = new int[]{5, 6, 7, 7, 4, 2, 8, 3, 9, 9, 9, 10, 10, 11};
        N = in.length;
        for (int i = 1; i <= N; i++) {
            arr[i] = in[i - 1];
            update(i, arr[i]);
        }
        int[] ans = queryMaxMin(1, 7);
        System.out.println("1~7区间最大值: " + ans[0]);//8
        System.out.println("1~7区间最小值: " + ans[1]);//2
    }

    static int lowBit(int x) {
        // x & x取反+1
        return x & -x;
    }
    //最大最小可拆分成两个方法
    static void update(int x, int val) {
        arr[x] = val;
        while (x <= maxN) {
            bitMax[x] = arr[x];
            bitMin[x] = arr[x];
            for (int i = 1; i < lowBit(x); i *= 2) {
                bitMax[x] = Math.max(bitMax[x], bitMax[x - i]);
                bitMin[x] = Math.min(bitMin[x], bitMin[x - i]);
            }
            x += lowBit(x);
        }
    }
    //最大最小可拆分成两个方法
    static int[] queryMaxMin(int x, int y) {
        int max = arr[y];
        int min = arr[y];
        while (y != x) {
            for (y--; y - lowBit(y) >= x; y -= lowBit(y)) {
                max = Math.max(max, bitMax[y]);
                min = Math.min(min, bitMin[y]);
            }
            max = Math.max(max, arr[y]);
            min = Math.min(min, arr[y]);
        }
        return new int[]{max, min};
    }
}

最长递增子序列的长度(LIS、LNDS)

LIS(Longest Increasing Subsequence):最长单调递增子序列长度
LNDS:最长不递减子序列长度

import java.util.Arrays;
import java.util.Comparator;

public class BIT_LIS_LNDS {
    static final int maxN = 50;
    static int N;
    static int lis = 0, lnds = 0;
    static Node[] lisArr = new Node[maxN];
    static Node[] lndsArr = new Node[maxN];
    static int[] bit = new int[maxN];
    static Comparator<Node> lisCmp = new Comparator<Node>() {
        @Override
        public int compare(Node o1, Node o2) {
            //lds 按值升序 索引降序
            return o1.val == o2.val ? o2.idx - o1.idx : o1.val - o2.val;
        }
    };
    static Comparator<Node> lndsCmp = new Comparator<Node>() {
        @Override
        public int compare(Node o1, Node o2) {
            //lds 按值升序 索引升序
            return o1.val == o2.val ? o1.idx - o2.idx : o1.val - o2.val;
        }
    };

    public static void main(String[] args) {
        int[] in = new int[]{5, 6, 7, 7, 4, 2, 8, 3, 9, 9, 9, 10, 10, 11};
        //lis(7):5 6 7 8 9 10 11
        //lnds(11):5 6 7 7 8 9 9 9 10 10 11
        N = in.length;
        for (int i = 0; i < N; i++) {
            lisArr[i + 1] = new Node(i + 1, in[i]);
            lndsArr[i + 1] = new Node(i + 1, in[i]);
        }
        Arrays.sort(lisArr, 1, N + 1, lisCmp);
        for (int i = 1; i <= N; i++) {
            int len = query(lisArr[i].idx-1) + 1;
            update(lisArr[i].idx, len);
            lis = Math.max(lis, len);
        }
        System.out.println("lis: " + lis);

        Arrays.fill(bit, 0);
        Arrays.sort(lndsArr, 1, N + 1, lndsCmp);
        for (int i = 1; i <= N; i++) {
            int len = query(lndsArr[i].idx-1) + 1;
            update(lndsArr[i].idx, len);
            lnds = Math.max(lnds, len);
        }
        System.out.println("lnds: " + lnds);
    }

    static int lowBit(int x) {
        // x & x取反+1
        return x & -x;
    }

    static void update(int x, int val) {
        while (x < maxN) {
            bit[x] = Math.max(bit[x], val);
            x += lowBit(x);
        }
    }

    static int query(int x) {
        int max = 0;
        while (x > 0) {
            max = Math.max(bit[x], max);
            x -= lowBit(x);
        }
        return max;
    }

    static class Node {
        int idx;
        int val;

        Node(int idx, int val) {
            this.idx = idx;
            this.val = val;
        }
    }
}

最长递减子序列的长度(LDS、LNIS)

LDS(Longest Decreasing Subsequence):最长单调递减子序列长度
LNIS:最长不递增子序列长度

import java.util.Arrays;
import java.util.Comparator;

public class BIT_LDS_LNIS {
    static final int maxN = 50;
    static int N;
    static int lds = 0, lnis = 0;
    static Node[] ldsArr = new Node[maxN];
    static Node[] lnisArr = new Node[maxN];
    static int[] bit = new int[maxN];

    static Comparator<Node> ldsCmp = new Comparator<Node>() {
        @Override
        public int compare(Node o1, Node o2) {
            //lds 按值升序 索引升序
            return o1.val == o2.val ? o1.idx - o2.idx : o1.val - o2.val;
        }
    };
    static Comparator<Node> lnisCmp = new Comparator<Node>() {
        @Override
        public int compare(Node o1, Node o2) {
            //lnis 按值升序 索引降序
            return o1.val == o2.val ? o2.idx - o1.idx : o1.val - o2.val;
        }
    };

    public static void main(String[] args) {
        int[] in = new int[]{11, 10, 10, 9, 9, 9, 3, 8, 2, 4, 7, 7, 6, 5};
        //lds(7):11 10 9 8 7 6 5
        //lnis(11):11 10 10 9 9 9 8 7 7 6 5
        N = in.length;
        for (int i = 0; i < N; i++) {
            ldsArr[i + 1] = new Node(i + 1, in[i]);
            lnisArr[i + 1] = new Node(i + 1, in[i]);
        }
        Arrays.sort(ldsArr, 1, N + 1, ldsCmp);
        for (int i = N; i >= 1; i--) {
            int len = query(ldsArr[i].idx-1) + 1;
            update(ldsArr[i].idx, len);
            lds = Math.max(lds, len);
        }
        System.out.println("lds: " + lds);

        Arrays.fill(bit, 0);
        Arrays.sort(lnisArr, 1, N + 1, lnisCmp);
        for (int i = N; i >= 1; i--) {
            int len = query(lnisArr[i].idx-1) + 1;
            update(lnisArr[i].idx, len);
            lnis = Math.max(lnis, len);
        }
        System.out.println("lnis: " + lnis);
    }

    static int lowBit(int x) {
        // x & x取反+1
        return x & -x;
    }

    static void update(int x, int val) {
        while (x < maxN) {
            bit[x] = Math.max(bit[x], val);
            x += lowBit(x);
        }
    }

    static int query(int x) {
        int max = 0;
        while (x > 0) {
            max = Math.max(bit[x], max);
            x -= lowBit(x);
        }
        return max;
    }

    static class Node {
        int idx;
        int val;

        Node(int idx, int val) {
            this.idx = idx;
            this.val = val;
        }
    }
}

单点更新–求区间和

public class BIT_Sum {
    static final int maxN = 50;
    static int N;
    static int[] bit = new int[maxN];

    public static void main(String[] args) throws Exception {
        int[] in = new int[]{5, 6, 7, 7, 4, 2, 8, 3, 9, 9, 9, 10, 10, 11};
        N = in.length;
        for (int i = 0; i < N; i++) {
            update(i + 1, in[i]);//初始化
        }
        //第x~y项的区间和 query(y) - query(x-1)
        //第x项的值改为y update(x, y)
        System.out.println(query(5) - query(2 - 1));//2~5项的区间和 6+7+7+4=24
        update(3, 10 - 7);//第三项由7改为10
        System.out.println(query(5) - query(2 - 1));//2~5项的区间和 6+10+7+4=27
    }

    static int lowBit(int x) {
        return x & (-x);
    }

    static void update(int x, int val) {
        while (x < maxN) {
            bit[x] += val;
            x += lowBit(x);
        }
    }

    static int query(int x) {
        int sum = 0;
        while (x > 0) {
            sum += bit[x];
            x -= lowBit(x);
        }
        return sum;
    }
}

区间更新–单点查询

区间更新,暴力方法:遍历区间,对区间中每个值更新,再求和。
优化方法:使用差分数组
原数组为ai[i], 差分数组di[i]=ai[i]−ai[i−1] (ai[0]=0),则 ai[i]=sum(di[j])

public class BIT_RangeUpdate_SingleQuery {
    static final int maxN = 50;
    static int N;
    static int[] bit = new int[maxN];

    public static void main(String[] args) throws Exception {
        int[] ai = new int[]{0, 5, 6, 7, 7, 4, 2, 8, 3, 9, 9, 9, 10, 10, 11};
        N = ai.length;
        // 差分数组di[i] = ai[i] - ai[i-1]
        // 0 5 1 1 0 -3 -2 6 -5 6 0 0 1 0 1
        int[] di = new int[N];
        for (int i = 1; i < N; i++) {
            di[i] = ai[i] - ai[i - 1];
            update(i, di[i]);
        }

        System.out.print("差分数组: ");
        for (int i = 0; i < N; i++) {
            System.out.print(di[i] + " ");
        }
        System.out.println();
        //区间修改,第x~y项的区间 都加2: rangeUpdate(x, y, 2)
        //单点查询,查询第x项的值 query(x)
        rangeUpdate(2, 7, 2);
        System.out.println(query(5));//ai[6]的值4+2=6
    }

    static int lowBit(int x) {
        return x & (-x);
    }

    /*
    区间修改
     */
    static void rangeUpdate(int x, int y, int val) {
        update(x, val);
        update(y + 1, -val);
    }

    static void update(int x, int val) {
        while (x < maxN) {
            bit[x] += val;
            x += lowBit(x);
        }
    }

    /*
    单点查询
     */
    static int query(int x) {
        int sum = 0;
        while (x > 0) {
            sum += bit[x];
            x -= lowBit(x);
        }
        return sum;
    }
}

区间更新–区间求和

public class BIT_RangeUpdate_RangeQuery {
    static final int maxN = 50;
    static int N;
    static int[] sum1 = new int[maxN];// sum1[i]=di[i]
    static int[] sum2 = new int[maxN];// sum2[i]=di[i]∗i

    public static void main(String[] args) throws Exception {
        int[] ai = new int[]{0, 5, 6, 7, 7, 4, 2, 8, 3, 9, 9, 9, 10, 10, 11};
        N = ai.length;
        // di[i] = ai[i] - ai[i-1]
        // 0 5 1 1 0 -3 -2 6 -5 6 0 0 1 0 1
        int[] di = new int[N + 1];
        for (int i = 1; i < N; i++) {
            di[i] = ai[i] - ai[i - 1];
            update(i, di[i]);
        }

        System.out.print("差分数组: ");
        for (int i = 0; i < N; i++) {
            System.out.print(di[i] + " ");
        }
        System.out.println();
        //区间修改,第x~y项的区间 都加2: rangeUpdate(x, y, 2)
        //区间查询,查询第x~y项区间的和 query(y) - query(x-1)
        rangeUpdate(2, 7, 2);
        System.out.println(query(5) - query(2 - 1));//ai[3]+...+ai[6]项的区间和 6+10+7+4=27
    }

    static int lowBit(int x) {
        return x & (-x);
    }

    static void rangeUpdate(int x, int y, int val) {
        update(x, val);
        update(y + 1, -val);
    }

    static void update(int x, int val) {
        int init = x;
        while (x < maxN) {
            sum1[x] += val;
            sum2[x] += init * val;
            x += lowBit(x);
        }
    }

    static int query(int x) {
        int init = x;
        int sum = 0;
        while (x > 0) {
            sum += (init + 1) * sum1[x] - sum2[x];
            x -= lowBit(x);
        }
        return sum;
    }
}

这个博客写得很好
[https://blog.csdn.net/bestsort/article/details/80796531]

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值