感觉这题的代码有点长，而且我有个地方写错了，找 bug 找了好久，今天终于找出来了……坑爹啊……
更新（下传标记）时注意先处理 set 操作：若某节点当前进行了 set 操作，则该节点之前的 add 操作都应该清除；下传时先下传 set，再下传 add。
代码:
#include <stdio.h>
#include <string.h>
// Column capacity of each row tree; the tree array below reserves 4*maxn slots per row.
#define maxn 50005

// One segment-tree node covering the inclusive column range [l, r] of one row.
// sum/max/min already include this node's own pending tags (assignment first,
// then add); tags still pending on ancestors are NOT included.
struct node
{
    int l, r;      // inclusive column range covered by this node
    int sum;       // sum of the elements in [l, r]
    int max, min;  // largest / smallest element in [l, r]
    int ad;        // pending "add" tag not yet pushed to the children
    int set;       // pending "assign" tag; 0 is treated as "no pending assignment"
};

// a[row][idx]: one segment tree per matrix row (rows are used as 1..r).
// idx 1 is the root; the children of idx are 2*idx and 2*idx+1.
struct node a[22][4*maxn];
// Return the larger of x and y (y when they are equal).
int Max(int x, int y)
{
    if (x > y)
        return x;
    return y;
}
// Return the smaller of x and y (y when they are equal).
int Min(int x, int y)
{
    if (x < y)
        return x;
    return y;
}
void Build(int l, int r, int m, int n)
{
a[m][n].l= l;
a[m][n].r= r;
a[m][n].sum= a[m][n].max= a[m][n].min= a[m][n].ad= a[m][n].set= 0;
if(a[m][n].l == a[m][n].r)
return;
int mid= (a[m][n].l + a[m][n].r)/ 2;
Build(l, mid, m, 2*n);
Build(mid+1, r, m, 2*n+1);
}
void pushdown(int m, int n)
{
if(a[m][n].set)
{
a[m][2*n].set= a[m][2*n+1].set= a[m][n].set;
a[m][n].set= 0;
a[m][2*n].ad= a[m][2*n+1].ad= 0;// 父节点的set往下传,并清除孩子的ad
a[m][2*n].max= a[m][2*n].min= a[m][2*n].set;
a[m][2*n].sum= (a[m][2*n].r - a[m][2*n].l + 1)* a[m][2*n].set;
a[m][2*n+1].max= a[m][2*n+1].min= a[m][2*n+1].set;
a[m][2*n+1].sum= (a[m][2*n+1].r - a[m][2*n+1].l + 1)* a[m][2*n+1].set;
}
if(a[m][n].ad)
{
a[m][2*n].ad+= a[m][n].ad;
a[m][2*n].sum+= (a[m][2*n].r - a[m][2*n].l + 1)* a[m][n].ad;
a[m][2*n].max+= a[m][n].ad;
a[m][2*n].min+= a[m][n].ad;
//父节点的ad传向左孩子,并更新左孩子节点的内容
//注意这里a[m][2*n].ad!= a[m][n].ad ,但是进行sz操作时 a[m][n].sz= a[m][2*n].sz
//我之前就是这里写成了a[m][2*n].ad 所以一直无限WA
a[m][2*n+1].ad+= a[m][n].ad;
a[m][2*n+1].sum+= (a[m][2*n+1].r- a[m][2*n+1].l + 1)*a[m][n].ad;
a[m][2*n+1].max+= a[m][n].ad;
a[m][2*n+1].min+= a[m][n].ad;
a[m][n].ad= 0; //清除父节点的ad
}
}
void Add(int l, int r, int m, int n, int num)
{
if(l== a[m][n].l && r== a[m][n].r)
{
a[m][n].ad+= num;
a[m][n].sum+= (r- l+ 1)*num;
a[m][n].max+= num;
a[m][n].min+= num;
return;
}
pushdown(m, n);
int mid= (a[m][n].l + a[m][n].r)/ 2;
if(r<= mid)
Add(l, r, m, 2*n, num);
else if(l> mid)
Add(l, r, m, 2*n+1, num);
else
{
Add(l, mid, m, 2*n, num);
Add(mid+1, r, m, 2*n+1, num);
}
a[m][n].sum= a[m][2*n].sum + a[m][2*n+1].sum;
a[m][n].max= Max(a[m][2*n].max, a[m][2*n+1].max);
a[m][n].min= Min(a[m][2*n].min, a[m][2*n+1].min);
}
void update(int l, int r, int m, int n, int num)
{
if(a[m][n].l== l && a[m][n].r== r)
{
a[m][n].set= num;
a[m][n].sum= (a[m][n].r- a[m][n].l+ 1)* num;
a[m][n].max= a[m][n].min= num;
a[m][n].ad= 0; //清除此节点之前的ad
return;
}
pushdown(m, n);
int mid= (a[m][n].l + a[m][n].r)/ 2;
if(r<= mid)
update(l, r, m, 2*n, num);
else if(l> mid)
update(l, r, m, 2*n+1, num);
else
{
update(l, mid, m, 2*n, num);
update(mid+1, r, m, 2*n+1, num);
}
a[m][n].sum= a[m][2*n].sum + a[m][2*n+1].sum;
a[m][n].max= Max(a[m][2*n].max, a[m][2*n+1].max);
a[m][n].min= Min(a[m][2*n].min, a[m][2*n+1].min);
}
int Qsum(int l, int r, int m, int n, int sum)//sum是所有祖先ad之和
{
// printf("%d %d %d %d %d\n",l ,r, m, n, sum);
if(a[m][n].l<= l && a[m][n].r>= r)
{
if(a[m][n].l== l && a[m][n].r== r)
return a[m][n].sum + (r- l +1)* sum;// 如果查询到[L,R]节点
else if(a[m][n].set) // 祖先节点中有sz,则不再往下查询
return (r- l + 1)* (a[m][n].set + a[m][n].ad + sum);
sum+= a[m][n].ad;
}
int mid= (a[m][n].l + a[m][n].r)/ 2;
if(r<= mid)
return Qsum(l, r, m, 2*n, sum);
else if(l> mid)
return Qsum(l, r, m, 2*n+1, sum);
else
return Qsum(l, mid, m, 2*n, sum) + Qsum(mid+1, r, m, 2*n+1, sum);
}
int Qmax(int l, int r, int m, int n, int sum)
{
if(a[m][n].l<= l && a[m][n].r >= r)
{
if((a[m][n].set)||(a[m][n].l == l && a[m][n].r == r))
return a[m][n].max+ sum;
sum+= a[m][n].ad;
}
int mid= (a[m][n].l + a[m][n].r)/ 2;
if(r<= mid)
return Qmax(l, r, m, 2*n, sum);
else if(l> mid)
return Qmax(l, r, m, 2*n+1, sum);
else
return Max( Qmax(l, mid, m, 2*n, sum), Qmax(mid+1, r, m, 2*n+1, sum) );
}
int Qmin(int l, int r, int m, int n, int sum)
{
if(a[m][n].l<= l && a[m][n].r >= r)
{
if((a[m][n].set)||(a[m][n].l == l && a[m][n].r == r))
return a[m][n].min+ sum;
sum+= a[m][n].ad;
}
int mid= (a[m][n].l + a[m][n].r)/ 2;
if(r<= mid)
return Qmin(l, r, m, 2*n, sum);
else if(l> mid)
return Qmin(l, r, m, 2*n+1, sum);
else
return Min( Qmin(l, mid, m, 2*n, sum), Qmin(mid+1, r, m, 2*n+1, sum) );
}
// Reads test cases until end of input. Each case starts with "r c m": an
// r x c all-zero matrix (rows and columns 1-based) followed by m operations:
//   1 x1 y1 x2 y2 v  -> add v to every cell of the sub-matrix
//   2 x1 y1 x2 y2 v  -> set every cell of the sub-matrix to v
//   3 x1 y1 x2 y2    -> print "sum min max" of the sub-matrix
// NOTE(review): assumes r <= 21 and c <= maxn, the capacity of the global tree array.
int main(void)
{
    int r, c, m;

    // Require a complete header. Testing "!= EOF" would spin forever on a
    // partial or malformed header, where scanf returns 0, 1 or 2.
    while (scanf("%d %d %d", &r, &c, &m) == 3)
    {
        for (int i = 1; i <= r; i++)
            Build(1, c, i, 1);  // one tree per row, reset for every test case

        while (m--)
        {
            int id, xx, yy, x2, y2, num;

            // Stop on truncated input instead of acting on uninitialised values.
            if (scanf("%d", &id) != 1)
                return 0;

            if (id == 1 || id == 2)
            {
                if (scanf("%d %d %d %d %d", &xx, &yy, &x2, &y2, &num) != 5)
                    return 0;
                for (int i = xx; i <= x2; i++)
                {
                    if (id == 1)
                        Add(yy, y2, i, 1, num);
                    else
                        update(yy, y2, i, 1, num);
                }
            }
            else if (id == 3)
            {
                int tot, hi, lo;

                if (scanf("%d %d %d %d", &xx, &yy, &x2, &y2) != 4)
                    return 0;

                // Seed with row xx, then fold in the remaining rows.
                tot = Qsum(yy, y2, xx, 1, 0);
                hi = Qmax(yy, y2, xx, 1, 0);
                lo = Qmin(yy, y2, xx, 1, 0);
                for (int i = xx + 1; i <= x2; i++)
                {
                    tot += Qsum(yy, y2, i, 1, 0);
                    hi = Max(hi, Qmax(yy, y2, i, 1, 0));
                    lo = Min(lo, Qmin(yy, y2, i, 1, 0));
                }
                printf("%d %d %d\n", tot, lo, hi);
            }
        }
    }
    return 0;
}