借此写一下二维树状数组
1.单点修改,区间查询
(原文此处为超链接,链接已丢失)
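区间查询:getsum(i,j) 返回从 (1,1) 到 (i,j) 的矩形内所有值之和(二维前缀和;a[i][j] 本身只存其中一块)。因此查询子矩形 $(x_1,y_1)$–$(x_2,y_2)$ 时要用容斥原理减去多算的部分:$S(x_2,y_2)-S(x_1-1,y_2)-S(x_2,y_1-1)+S(x_1-1,y_1-1)$,其中 $S$ 即 getsum。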
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
// Largest supported grid dimension; n and m must stay below maxn because indices run up to n and m.
const int maxn = 4100;
// 64-bit type used for every stored and returned sum.
using ll = long long;
// Number of rows (n) and columns (m); read from input in main.
int n, m;
// 2D Fenwick tree: a[i][j] stores the sum of the lowbit(i) x lowbit(j) block of cells ending at (i, j).
ll a[maxn][maxn];
// Returns the value of the lowest set bit of x (lowbit(12) == 4, lowbit(0) == 0).
// Relies on two's complement: -x keeps the lowest set bit of x and flips every bit above it.
int lowbit(int x)
{
    int negated = -x;
    return negated & x;
}
// Adds val to grid cell (x, y) by touching every tree node whose block covers that cell.
// Callers must pass x >= 1 and y >= 1: at index 0 lowbit returns 0 and the loop would never advance.
void update(int x, int y, ll val)
{
    for (int row = x; row <= n; row += lowbit(row))
    {
        for (int col = y; col <= m; col += lowbit(col))
        {
            a[row][col] += val;
        }
    }
}
// Returns the 2D prefix sum of the grid over the rectangle (1, 1) .. (x, y).
// Returns 0 when x <= 0 or y <= 0, so callers may pass x1 - 1 / y1 - 1 directly.
ll getsum(int x, int y)
{
    ll total = 0;
    int row = x;
    while (row > 0)
    {
        for (int col = y; col > 0; col -= lowbit(col))
            total += a[row][col];
        row -= lowbit(row);
    }
    return total;
}
int main(){
scanf("%d%d",&n,&m);
memset(a,0,sizeof a);
int t;
while(~scanf("%d",&t))
{
int x1,y1,x2,y2;
ll val;
if(t==1)
{
scanf("%d%d%lld",&x1,&y1,&val);
update(x1,y1,val);
}
else
{
scanf("%d%d%d%d",&x1,&y1,&x2,&y2);
printf("%lld\n",getsum(x2,y2)-getsum(x1-1,y2)-getsum(x2,y1-1)+getsum(x1-1,y1-1));
}
}
return 0;
}
2.区间修改,单点查询
(原文此处为超链接,链接已丢失)
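区间修改与区间查询是一样的道理:用树状数组维护差分数组 $d$。给子矩形 $(x_1,y_1)$–$(x_2,y_2)$ 整体加 $k$ 时,只需修改四个角:$d(x_1,y_1)\mathrel{+}=k$,$d(x_1,y_2+1)\mathrel{-}=k$,$d(x_2+1,y_1)\mathrel{-}=k$,$d(x_2+1,y_2+1)\mathrel{+}=k$。单点 $(x,y)$ 的值就是 $d$ 的二维前缀和 getsum(x,y)。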
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
using ll = long long;
// Largest supported grid dimension; n and m must stay below maxn because indices run up to n and m.
const int maxn = 4100;
// Number of rows (n) and columns (m); read from input in main.
int n, m;
// 2D Fenwick tree over the grid's difference array: the prefix sum getsum(x, y)
// is the current value of cell (x, y).
ll a[maxn][maxn];
// Returns the value of the lowest set bit of x (lowbit(12) == 4, lowbit(0) == 0).
// ~x + 1 is the two's-complement negation of x; ANDing it with x keeps only the lowest set bit.
int lowbit(int x)
{
    return x & (~x + 1);
}
void update(int x1,int y1,int val)
{
for(int x=x1;x<=n;x+=lowbit(x))
for(int y=y1;y<=m;y+=lowbit(y))
a[x][y]+=val;
}
// Point query: returns the current value of cell (x, y). This is the 2D prefix sum of the
// difference array over (1, 1) .. (x, y). Returns 0 when x <= 0 or y <= 0.
ll getsum(int x, int y)
{
    ll acc = 0;
    for (int r = x; r > 0; r -= lowbit(r))
    {
        int c = y;
        while (c > 0)
        {
            acc += a[r][c];
            c -= lowbit(c);
        }
    }
    return acc;
}
// Reads "n m", then processes operations until end of input:
//   "1 x1 y1 x2 y2 k" -> add k to every cell of the rectangle (x1, y1) .. (x2, y2)
//   "? x y"           -> print the current value of cell (x, y)
// Any operation code other than 1 is treated as a point query.
int main()
{
    scanf("%d%d", &n, &m);
    memset(a, 0, sizeof a);

    int op;
    while (scanf("%d", &op) != EOF)
    {
        if (op != 1)
        {
            int px, py;
            scanf("%d%d", &px, &py);
            printf("%lld\n", getsum(px, py));
            continue;
        }

        int x1, y1, x2, y2, val;
        scanf("%d%d%d%d%d", &x1, &y1, &x2, &y2, &val);
        // Four-corner difference update, in this order:
        // +val at (x1, y1), -val at (x1, y2 + 1), -val at (x2 + 1, y1), +val at (x2 + 1, y2 + 1).
        const int cx[4] = {x1, x1, x2 + 1, x2 + 1};
        const int cy[4] = {y1, y2 + 1, y1, y2 + 1};
        const int sign[4] = {1, -1, -1, 1};
        for (int k = 0; k < 4; ++k)
            update(cx[k], cy[k], sign[k] * val);
    }
    return 0;
}
3.区间修改,区间查询
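(原文此处为超链接,链接已丢失)
思路:仍然维护差分数组 $d$(即 $a_{i,j}=\sum_{p\le i,\,q\le j} d_{p,q}$),则
$$\sum_{i=1}^{x}\sum_{j=1}^{y}a_{i,j}=\sum_{p=1}^{x}\sum_{q=1}^{y}d_{p,q}(x+1-p)(y+1-q)=(x+1)(y+1)\sum d_{p,q}-(y+1)\sum d_{p,q}\,p-(x+1)\sum d_{p,q}\,q+\sum d_{p,q}\,pq,$$
所以用四个树状数组分别维护 $d_{p,q}$、$d_{p,q}\cdot p$、$d_{p,q}\cdot q$、$d_{p,q}\cdot pq$(即代码中的 s1、s2、s3、s4)。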
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
// Largest supported grid dimension; n and m must stay below maxn because indices run up to n and m.
const int maxn=4100;
typedef long long ll;
// Four 2D Fenwick trees over the grid's difference array d. For each point update of d at
// (x, y) by val, update adds val, val*x, val*y and val*x*y into s1, s2, s3 and s4 respectively.
// getsum combines them into a rectangle prefix sum.
// Each array is maxn*maxn*8 bytes (~134 MB), so no unused tree is kept alongside them.
ll s1[maxn][maxn],s2[maxn][maxn],s3[maxn][maxn],s4[maxn][maxn];
// Grid dimensions; declared ll because main reads them with %lld.
ll n,m;
// Returns the value of the lowest set bit of x (lowbit(6) == 2, lowbit(0) == 0).
// x & (x - 1) clears the lowest set bit; XOR with x leaves exactly that bit.
int lowbit(int x)
{
    const int cleared = x & (x - 1);
    return x ^ cleared;
}
void update(int x,int y,int val)
{
for(int i=x;i<=n;i+=lowbit(i))
for(int j=y;j<=m;j+=lowbit(j))
{
s1[i][j]+=val;
s2[i][j]+=val*x;
s3[i][j]+=val*y;
s4[i][j]+=val*x*y;
}
}
// Returns the sum of the grid over the rectangle (1, 1) .. (x, y).
// With d the difference array, that sum equals
//   sum_{p<=x, q<=y} d[p][q] * (x + 1 - p) * (y + 1 - q)
//   = (x+1)(y+1)*S1 - (y+1)*S2 - (x+1)*S3 + S4,
// where S1..S4 are the prefix sums of d, d*p, d*q and d*p*q held in s1..s4.
// Returns 0 when x <= 0 or y <= 0, so callers may pass x1 - 1 / y1 - 1 directly.
ll getsum(int x,int y)
{
    ll sum1=0,sum2=0,sum3=0,sum4=0;
    for(int i=x;i>0;i-=lowbit(i))
        for(int j=y;j>0;j-=lowbit(j))
        {
            sum1+=s1[i][j];
            sum2+=s2[i][j];
            sum3+=s3[i][j];
            sum4+=s4[i][j];
        }
    return sum1*(x+1)*(y+1)-sum2*(y+1)-sum3*(x+1)+sum4;
}

// Reads "n m", then processes operations until end of input:
//   "1 x1 y1 x2 y2 k" -> add k to every cell of the rectangle (x1, y1) .. (x2, y2)
//   "? x1 y1 x2 y2"   -> print the sum over that rectangle
// Any operation code other than 1 is treated as a rectangle query.
int main()
{
    scanf("%lld%lld",&n,&m);
    ll t;
    while(~scanf("%lld",&t))
    {
        ll x1,y1,x2,y2,val;
        if(t==1)
        {
            scanf("%lld%lld%lld%lld%lld",&x1,&y1,&x2,&y2,&val);
            // Four-corner update of the difference array.
            update(x1,y1,val);
            update(x1,y2+1,-val);
            update(x2+1,y1,-val);
            update(x2+1,y2+1,val);
        }
        else
        {
            // x1..y2 are long long, so the conversion must be %lld. Passing a long long* to %d is
            // undefined behaviour; in practice it writes only part of each variable, leaving garbage.
            scanf("%lld%lld%lld%lld",&x1,&y1,&x2,&y2);
            // Inclusion-exclusion over four prefix sums.
            ll res=getsum(x2,y2)-getsum(x1-1,y2)-getsum(x2,y1-1)+getsum(x1-1,y1-1);
            printf("%lld\n",res);
        }
    }
    return 0;
}