[LeetCode308]Range Sum Query 2D - Mutable

Given a 2D matrix `matrix`, find the sum of the elements inside the rectangle defined by its upper left corner (row1, col1) and lower right corner (row2, col2).

[Figure omitted: "Range Sum Query 2D" — the example matrix below with the query rectangle outlined in red.]
The above rectangle (with the red border) is defined by (row1, col1) = (2, 1) and (row2, col2) = (4, 3), which contains sum = 8.

Example:
Given matrix = [
      [3, 0, 1, 4, 2],
      [5, 6, 3, 2, 1],
      [1, 2, 0, 1, 5],
      [4, 1, 0, 1, 7],
      [1, 0, 3, 0, 5]
    ]

sumRegion(2, 1, 4, 3) -> 8
update(3, 2, 2)
sumRegion(2, 1, 4, 3) -> 10
Note:
    The matrix is only modifiable by the update function.
    You may assume the number of calls to update and sumRegion function is distributed evenly.
    You may assume that row1 ≤ row2 and col1 ≤ col2.

Tags: Segment Tree, Binary Indexed Tree
Similar problems: (M) Range Sum Query 2D - Immutable, (M) Range Sum Query - Mutable

这道题看起来很花哨，代码也会很长，但其实就是 segment tree 的变形：每个节点把自己负责的矩形按行、列的中点切成四块（也就是一棵四叉树）。1D 的版本搞懂了，就可以慢慢摸索出 2D 的。

/// Node of a 2D segment tree (a region quadtree) used by NumMatrix below.
/// It covers the inclusive cell rectangle [leftTop, rightBottom], where each
/// corner is a (row, col) pair, and `sum` holds the total of that rectangle.
class SegmentTreeSumNode {
public:
    int sum;
    // Child quadrants (the field name is presumably a misspelling of
    // "neighbor"; it holds children, not siblings). Index order, as filled in
    // by NumMatrix::buildSegmentTree:
    //   0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right.
    // A slot is nullptr when that quadrant is empty, which happens for leaves
    // and for nodes spanning a single row or a single column.
    SegmentTreeSumNode *negibor[4] = {nullptr, nullptr, nullptr, nullptr};
    std::pair<int, int> leftTop = std::make_pair(0, 0);
    std::pair<int, int> rightBottom = std::make_pair(0, 0);
    SegmentTreeSumNode(int sum) : sum(sum) {}
};

/// Mutable 2D range-sum query over an integer matrix, backed by a
/// SegmentTreeSumNode quadtree. Each NumMatrix owns the nodes it allocates:
/// copies build their own independent tree and the destructor frees it.
class NumMatrix {
    // Root of the quadtree; stays nullptr when the input matrix is empty.
    SegmentTreeSumNode* root = nullptr;
    // Current cell values, kept so update() can turn a new value into a delta.
    std::vector<std::vector<int>> nums;
public:
    /// Builds the tree over `matrix`. For an empty matrix (or an empty first
    /// row) no tree is built: sumRegion() returns 0 and update() is a no-op.
    NumMatrix(std::vector<std::vector<int>> &matrix) {
        if (matrix.empty()) return;
        nums = matrix;
        buildRoot();
    }

    /// Deep copy: rebuilds an independent tree from the copied values so two
    /// objects never share (and never double-free) nodes.
    NumMatrix(const NumMatrix& other) : nums(other.nums) { buildRoot(); }

    /// Move: takes over `other`'s tree and values, leaving `other` empty.
    NumMatrix(NumMatrix&& other) noexcept : root(other.root) {
        other.root = nullptr;
        nums.swap(other.nums);
    }

    /// Copy- and move-assignment via copy-and-swap; the previous tree is
    /// released when the by-value parameter is destroyed.
    NumMatrix& operator=(NumMatrix other) noexcept {
        SegmentTreeSumNode* previous = root;
        root = other.root;
        other.root = previous;
        nums.swap(other.nums);
        return *this;
    }

    ~NumMatrix() { destroy(root); }

    /// Sets cell (row, col) to `val` and applies the difference to every
    /// node whose rectangle contains that cell.
    void update(int row, int col, int val) {
        if (!root) return;  // empty matrix: there is no cell to update
        int diff = val - nums[row][col];
        if (!diff) return;
        nums[row][col] = val;
        update(row, col, diff, root);
    }

    /// Returns the sum of the inclusive rectangle (row1, col1)..(row2, col2);
    /// assumes row1 <= row2 and col1 <= col2. Returns 0 for an empty matrix.
    int sumRegion(int row1, int col1, int row2, int col2) {
        int res = 0;
        if (root) sumRange(row1, col1, row2, col2, root, res);
        return res;
    }

    /// Recursively builds the node covering the inclusive rectangle
    /// [start, end] by splitting it at the row and column midpoints.
    /// Returns nullptr for an empty rectangle. The caller owns the result.
    SegmentTreeSumNode* buildSegmentTree(std::vector<std::vector<int>>& matrix, std::pair<int, int> start, std::pair<int, int> end) {
        if (start.first > end.first || start.second > end.second) return nullptr;
        SegmentTreeSumNode* node = new SegmentTreeSumNode(0);
        node->leftTop = start;
        node->rightBottom = end;
        if (start == end) {
            node->sum = matrix[start.first][start.second];
            return node;
        }
        int midx = (start.first + end.first) / 2;    // last row of the top half
        int midy = (start.second + end.second) / 2;  // last column of the left half
        node->negibor[0] = buildSegmentTree(matrix, start, std::make_pair(midx, midy));
        node->negibor[1] = buildSegmentTree(matrix, std::make_pair(start.first, midy + 1), std::make_pair(midx, end.second));
        node->negibor[2] = buildSegmentTree(matrix, std::make_pair(midx + 1, start.second), std::make_pair(end.first, midy));
        node->negibor[3] = buildSegmentTree(matrix, std::make_pair(midx + 1, midy + 1), end);
        for (SegmentTreeSumNode* child : node->negibor) {
            if (child) node->sum += child->sum;
        }
        return node;
    }

    /// Adds `diff` to `node` and to every descendant whose rectangle contains
    /// (row, col); subtrees not containing the cell are left untouched.
    void update(int row, int col, int diff, SegmentTreeSumNode* node) {
        if (!node) return;
        const std::pair<int, int>& lt = node->leftTop;
        const std::pair<int, int>& rb = node->rightBottom;
        if (row < lt.first || row > rb.first || col < lt.second || col > rb.second) return;
        node->sum += diff;
        for (SegmentTreeSumNode* child : node->negibor) {
            update(row, col, diff, child);
        }
    }

    /// Adds to `res` the sum of the cells that lie both in the query
    /// rectangle (row1, col1)..(row2, col2) and in `node`'s rectangle.
    void sumRange(int row1, int col1, int row2, int col2, SegmentTreeSumNode* node, int& res) {
        if (!node) return;
        const std::pair<int, int>& start = node->leftTop;
        const std::pair<int, int>& end = node->rightBottom;
        // The query and this node's rectangle do not overlap.
        if (std::min(end.first, row2) < std::max(start.first, row1)) return;
        if (std::min(end.second, col2) < std::max(start.second, col1)) return;
        // The node lies entirely inside the query: use its precomputed sum.
        if (row1 <= start.first && col1 <= start.second && row2 >= end.first && col2 >= end.second) {
            res += node->sum;
            return;
        }
        for (SegmentTreeSumNode* child : node->negibor) {
            sumRange(row1, col1, row2, col2, child, res);
        }
    }

private:
    // Builds `root` from `nums` when `nums` has at least one non-empty first row;
    // expects `root` to be nullptr on entry.
    void buildRoot() {
        if (nums.empty() || nums[0].empty()) return;
        const int m = static_cast<int>(nums.size());
        const int n = static_cast<int>(nums[0].size());
        root = buildSegmentTree(nums, std::make_pair(0, 0), std::make_pair(m - 1, n - 1));
    }

    // Frees a subtree, children first.
    static void destroy(SegmentTreeSumNode* node) {
        if (!node) return;
        for (SegmentTreeSumNode* child : node->negibor) destroy(child);
        delete node;
    }
};

// Your NumMatrix object will be instantiated and called as such:
// NumMatrix numMatrix(matrix);
// numMatrix.sumRegion(0, 1, 2, 3);
// numMatrix.update(1, 1, 10);
// numMatrix.sumRegion(1, 2, 3, 4);

12/15/2015:update:

今天看了 Fenwick tree（binary indexed tree），用它可以让 update 和 sumRegion 都在 $O(\log m \cdot \log n)$ 时间内完成：

/// Mutable 2D range-sum query backed by a 2D Fenwick tree (binary indexed
/// tree) with 1-based indices. update() and sumRegion() each walk
/// O(log m * log n) tree cells; the constructor performs one point update
/// per matrix cell.
class NumMatrix {
public:
    /// Copies `matrix` and loads every cell into the tree. An empty matrix,
    /// or one whose first row is empty, leaves the object with no storage.
    NumMatrix(std::vector<std::vector<int>> &matrix) {
        const bool noCells = matrix.empty() || matrix[0].empty();
        if (noCells) return;
        rows_ = static_cast<int>(matrix.size());
        cols_ = static_cast<int>(matrix[0].size());
        values_ = matrix;
        tree_.assign(rows_ + 1, std::vector<int>(cols_ + 1, 0));
        for (int r = 0; r < rows_; ++r) {
            for (int c = 0; c < cols_; ++c) {
                pointAdd(r, c, matrix[r][c]);
            }
        }
    }

    /// Sets cell (row, col) to `val` by adding the difference into the tree.
    void update(int row, int col, int val) {
        const int delta = val - values_[row][col];
        values_[row][col] = val;
        pointAdd(row, col, delta);
    }

    /// Sum of the inclusive rectangle (row1, col1)..(row2, col2), computed by
    /// inclusion-exclusion over four prefix rectangles anchored at (0, 0).
    int sumRegion(int row1, int col1, int row2, int col2) {
        const bool hasAbove = row1 > 0;
        const bool hasLeft = col1 > 0;
        const int whole = prefixSum(row2, col2);
        const int corner = (hasAbove && hasLeft) ? prefixSum(row1 - 1, col1 - 1) : 0;
        const int above = hasAbove ? prefixSum(row1 - 1, col2) : 0;
        const int left = hasLeft ? prefixSum(row2, col1 - 1) : 0;
        return whole - above - left + corner;
    }

private:
    std::vector<std::vector<int>> values_;  // current cell values (0-based)
    std::vector<std::vector<int>> tree_;    // Fenwick cells, indexed 1..rows_ x 1..cols_
    int rows_ = 0;
    int cols_ = 0;

    // Lowest set bit of i: the span a Fenwick index is responsible for.
    static int lowbit(int i) { return i & -i; }

    // Adds `delta` to the 0-based cell (row, col) in every covering tree cell.
    void pointAdd(int row, int col, int delta) {
        for (int i = row + 1; i <= rows_; i += lowbit(i)) {
            for (int j = col + 1; j <= cols_; j += lowbit(j)) {
                tree_[i][j] += delta;
            }
        }
    }

    // Sum of the 0-based rectangle (0, 0)..(row, col); a negative row or col
    // yields 0 because the corresponding loop never runs.
    int prefixSum(int row, int col) const {
        int total = 0;
        for (int i = row + 1; i > 0; i -= lowbit(i)) {
            for (int j = col + 1; j > 0; j -= lowbit(j)) {
                total += tree_[i][j];
            }
        }
        return total;
    }
};

YouTube 上有个视频讲得蛮好，不过是 1D 的；2D 的再研究一下吧。
下面这个讨论帖可以看一下：
https://leetcode.com/discuss/71169/java-2d-binary-indexed-tree-solution-clean-and-short-17ms

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值