题目描述
qn姐姐最好了~
qn姐姐给你了一个长度为n的序列还有m次操作让你玩,
1 l r 询问区间[l,r]内的元素和
2 l r 询问区间[l,r]内的元素的平方和
3 l r x 将区间[l,r]内的每一个元素都乘上x
4 l r x 将区间[l,r]内的每一个元素都加上x
输入描述:第一行两个数n,m
接下来一行n个数表示初始序列
就下来m行每行第一个数为操作方法opt,
若opt=1或者opt=2,则之后跟着两个数为l,r
若opt=3或者opt=4,则之后跟着三个数为l,r,x
操作意思为题目描述里说的
输出描述:对于每一个操作1,2,输出一行表示答案
示例1
输入
5 6
1 2 3 4 5
1 1 5
2 1 5
3 1 2 1
4 1 3 2
1 1 4
2 2 3
输出
15
55
16
41
备注:对于100%的数据 n=10000,m=200000 (注意是等于号)
保证所有询问的答案在long long 范围内
PS:比较难的线段树吧,但是这道题可以暴力过,线段树注意数据更新和懒标记下放。
AC代码:#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
const int maxn=1e4+10;
const int mod=1e9+7;
const int inf=1e8;
#define me(a,b) memset(a,b,sizeof(a))
#define lowbit(x) x&(-x)
typedef long long ll;
using namespace std;
ll sum[maxn<<2],num[maxn<<2],add[maxn<<2],mul[maxn<<2];
void init(int n)
{
for(int i=1;i<=n;i++)
sum[i]=num[i]=add[i]=0,mul[i]=1;
}
void up_data(int rt)
{
num[rt]=num[rt<<1]+num[rt<<1|1];
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}
void build(int l,int r,int rt)
{
if(l==r)
{
scanf("%lld",&num[rt]);
sum[rt]=num[rt]*num[rt];
return ;
}
int mid=(l+r)>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
up_data(rt);
}
void push_down(int l,int r,int rt)
{
int m=(l+r)>>1;
if(add[rt])
{
sum[rt<<1]+=(m-l+1)*add[rt]*add[rt]+2*add[rt]*num[rt<<1];
sum[rt<<1|1]+=(r-m)*add[rt]*add[rt]+2*add[rt]*num[rt<<1|1];
add[rt<<1]+=add[rt],add[rt<<1|1]+=add[rt];
num[rt<<1]+=add[rt]*(m-l+1),num[rt<<1|1]+=add[rt]*(r-m);
add[rt]=0;
}
if(mul[rt]>1)
{
mul[rt<<1]*=mul[rt],mul[rt<<1|1]*=mul[rt];
add[rt<<1]*=mul[rt],add[rt<<1|1]*=mul[rt];
num[rt<<1]*=mul[rt],num[rt<<1|1]*=mul[rt];
sum[rt<<1]*=mul[rt]*mul[rt],sum[rt<<1|1]*=mul[rt]*mul[rt];
mul[rt]=1;
}
}
void push_data(int flog,int x,int L,int R,int l,int r,int rt)
{
if(L<=l&&R>=r)
{
if(flog==3)
add[rt]*=x,num[rt]*=x,mul[rt]*=x,sum[rt]*=x*x;
else
{
sum[rt]+=(r-l+1)*x*x+2*num[rt]*x;
add[rt]+=x,num[rt]+=(r-l+1)*x;
}
return ;
}
push_down(l,r,rt);
int mid=(l+r)>>1;
if(L<=mid)
push_data(flog,x,L,R,l,mid,rt<<1);
if(R>mid)
push_data(flog,x,L,R,mid+1,r,rt<<1|1);
up_data(rt);
}
ll get_sum(int flog,int L,int R,int l,int r,int rt)
{
if(L<=l&&R>=r)
{
if(flog==1)
return num[rt];
return sum[rt];
}
push_down(l,r,rt);
int mid=(l+r)>>1;
ll s=0;
if(L<=mid)
s+=get_sum(flog,L,R,l,mid,rt<<1);
if(R>mid)
s+=get_sum(flog,L,R,mid+1,r,rt<<1|1);
return s;
}
int main()
{
int n,m;
cin>>n>>m;
init(n);
build(1,n,1);
while(m--)
{
int opt,l,r,k;
scanf("%d",&opt);
if(opt==1||opt==2)
{
scanf("%d%d",&l,&r);
printf("%lld\n",get_sum(opt,l,r,1,n,1));
}
else
{
scanf("%d%d%d",&l,&r,&k);
push_data(opt,k,l,r,1,n,1);
}
}
return 0;
}