题目描述
给你一个二维矩阵 matrix，需要支持两种操作：update(row, col, val) 将 matrix[row][col] 的值更新为 val；sumRegion(row1, col1, row2, col2) 计算以 (row1, col1) 为左上角、(row2, col2) 为右下角的矩形内所有元素的和。
例如，在下面的示例中，由左上角 (row1, col1) = (2, 1) 和右下角 (row2, col2) = (4, 3) 确定的矩形（原题图中用粉色框标出）内所有元素的总和 sum = 8。
示例:
给定 matrix = [
[3, 0, 1, 4, 2],
[5, 6, 3, 2, 1],
[1, 2, 0, 1, 5],
[4, 1, 0, 1, 7],
[1, 0, 3, 0, 5]
]
sumRegion(2, 1, 4, 3) -> 8
update(3, 2, 2)
sumRegion(2, 1, 4, 3) -> 10
解题思路
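对矩阵的每一行分别建立一棵求和线段树（也可以用树状数组实现同样的功能）。update(row, col, val) 只需对第 row 行的线段树做一次单点修改；sumRegion 则对 row1 到 row2 之间的每一行做一次区间 [col1, col2] 的求和，再把各行结果累加。下面的参考代码使用线段树实现。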
参考代码
// Node of a sum segment tree covering the closed index interval [l, r].
// A freshly constructed node has sum 0 and no children.
struct SegTreeNode{
    int l;                          // left bound of the covered interval (inclusive)
    int r;                          // right bound of the covered interval (inclusive)
    int sum = 0;                    // aggregated sum of the elements in [l, r]
    SegTreeNode * left = nullptr;   // left child, or null
    SegTreeNode * right = nullptr;  // right child, or null

    SegTreeNode(int lo, int hi) : l(lo), r(hi) {}
};
// Recomputes an internal node's sum from its children; a missing child
// contributes 0.
// Leaves (no children) are left untouched because their sum is the stored
// element itself.
// Returns false for a NULL node, true otherwise.
bool pushupTree(SegTreeNode *root){
    if(!root){
        return false;
    }
    if(!root->left && !root->right){
        return true;
    }
    // Start from zero so a node with only a right child does not keep a
    // stale value from a previous computation.
    int total = 0;
    if(root->left){
        total += root->left->sum;
    }
    if(root->right){
        total += root->right->sum;
    }
    root->sum = total;
    return true;
}
// Builds a sum segment tree over arr[l..r] (both ends inclusive) and returns
// its root, or a null pointer when the interval is empty (l > r).
// Every node is allocated with `new`; the caller owns the returned tree.
SegTreeNode * buildTree(vector<int> & arr,int l,int r){
    if(l > r){
        return nullptr;
    }
    SegTreeNode * node = new SegTreeNode(l, r);
    if(l < r){
        // Split at the midpoint: left child covers [l, mid], right [mid + 1, r].
        const int mid = (l + r) / 2;
        node->left = buildTree(arr, l, mid);
        node->right = buildTree(arr, mid + 1, r);
        pushupTree(node);
    }else{
        node->sum = arr[l];  // single-element interval: the leaf stores the value
    }
    return node;
}
// Returns the sum of the elements whose indices lie in [l, r] (inclusive)
// within the tree rooted at root.
// Returns 0 for a NULL tree, an empty range (l > r), or a range disjoint
// from the tree.
int searchTree(SegTreeNode * root,int l,int r){
    if(!root){
        return 0;
    }
    // Prune subtrees that do not intersect [l, r]. Without this check every
    // node outside the query range is visited down to its leaves, making a
    // query O(n) instead of O(log n).
    if(root->r < l || root->l > r){
        return 0;
    }
    if(root->l >= l && root->r <= r){
        return root->sum;  // node interval lies entirely inside the query
    }
    return searchTree(root->left,l,r) + searchTree(root->right,l,r);
}
// Sets the element at index i to val and refreshes the sums on the path
// from the root down to that leaf.
// Returns true if the element was updated, false if the tree is NULL or i
// lies outside the tree's interval (the tree is then left unchanged).
bool updateTree(SegTreeNode * root,int i,int val){
    if(!root || i < root->l || i > root->r){
        return false;
    }
    if(root->l == root->r){
        root->sum = val;  // leaf for index i: overwrite the stored element
        return true;
    }
    // Descend into the half that contains i (left covers [l, mid]).
    int mid = (root->l + root->r)/2;
    bool updated = (i > mid) ? updateTree(root->right,i,val)
                             : updateTree(root->left,i,val);
    if(updated){
        pushupTree(root);
    }
    return updated;
}
// Debug helper: prints every leaf of the tree, left to right, one per line
// in the form "[index:value]" on standard output.
// Returns false for an empty (null) tree, true otherwise.
bool debugTree(SegTreeNode * root){
    if(root == nullptr){
        return false;
    }
    const bool isLeaf = (root->l == root->r);
    if(!isLeaf){
        // In-order over leaves: left subtree first, then right.
        debugTree(root->left);
        debugTree(root->right);
        return true;
    }
    cout << "[" << root->l << ":" << root->sum << "]" << endl;
    return true;
}
class NumMatrix {
public:
NumMatrix(vector<vector<int>>& matrix) {
this->row = matrix.size();
if(this->row == 0){
return;
}
this->colum = matrix[0].size();
for(int i = 0;i < matrix.size(); ++i){
this->roots.push_back(buildTree(matrix[i],0,this->colum-1));
}
}
void update(int row, int col, int val) {
updateTree(this->roots[row],col,val);
}
int sumRegion(int row1, int col1, int row2, int col2) {
int res = 0;
for(int i = row1; i <= row2; ++i){
res += searchTree(this->roots[i],col1,col2);
}
return res;
}
private:
int row;
int colum;
vector<SegTreeNode *> roots;;
};
/**
* Your NumMatrix object will be instantiated and called as such:
* NumMatrix* obj = new NumMatrix(matrix);
* obj->update(row,col,val);
* int param_2 = obj->sumRegion(row1,col1,row2,col2);
*/