我的第一道题解就写它吧。
维护区间和+花式修改。用线段树或者树状数组可以解决,但是我没怎么写过树状数组。
维护和的操作直接把左右子树和加起来。
重点是修改。
去年刚学完线段树刷完数列操作a,b,c后看道这题就弃了。现在知道了关键就是推式子,跟HAOI2012高速公路是一个套路的。
线段树就是用于维护区间的,而且因为延迟标记的存在,所以我们先考虑区间增量。
对于区间 [l,r] 增量为 Σri=l(i-L)*x。L是总的修改范围,l,r,是线段树中节点的范围,第一次就是因为没注意这两者的关系,导致公式错误连样例都过不了。
提公因式x,Σ里是等差数列直接求和之后相乘就算出增量了,这时候要考虑如何打lazy标记。
思想也是提公因式。
Σri=l ( i-L ) * x = ( Σri=li ) * x - ( Σri=l1 ) * L * x
设 A = Σri=li,B = Σri=l1 ,可以发现对于线段树中的每段区间A和B的值是固定的,这样一来我们只需要累计每次修改的x以及L*x的值就可以顺利下传延迟修改了。
这道题就顺利解决了。
// q.c
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
typedef long long LL;
const int M=300000+10;
const int mod=(int)1e9+7;
struct Node {
int l,r,sum,s1,s2;
Node():l(0),r(0),sum(0),s1(0),s2(0) {}
};
struct SegmentTree {
int root; Node nd[M<<2];
SegmentTree():root(1) {}
void update(int o) {
nd[o].sum=((LL)nd[o<<1].sum+nd[o<<1^1].sum)%mod;
}
void pushdown(int o) {
Node &p=nd[o],&lc=nd[o<<1],&rc=nd[o<<1^1];
lc.sum=(lc.sum+(LL)p.s1*(lc.r+lc.l)*(lc.r-lc.l+1)/2)%mod;
lc.sum=(lc.sum-(LL)p.s2*(lc.r-lc.l+1)%mod+mod)%mod;
lc.s1=(lc.s1+p.s1)%mod;
lc.s2=(lc.s2+p.s2)%mod;
rc.sum=(rc.sum+(LL)p.s1*(rc.r+rc.l)*(rc.r-rc.l+1)/2)%mod;
rc.sum=(rc.sum-(LL)p.s2*(rc.r-rc.l+1)%mod+mod)%mod;
rc.s1=(rc.s1+p.s1)%mod;
rc.s2=(rc.s2+p.s2)%mod;
p.s1=p.s2=0;
}
void build(int o,int l,int r) {
nd[o].l=l,nd[o].r=r;
if(l!=r) {
int mid=(l+r)>>1;
build(o<<1,l,mid);
build(o<<1^1,mid+1,r);
}
}
void add(int o,int l,int r,int x) {
Node &p=nd[o];
if(l<=p.l&&p.r<=r) {
p.sum=(p.sum+(LL)x*(p.r-l+p.l-l)*(p.r-p.l+1)/2)%mod;
p.s1=(p.s1+x)%mod;
p.s2=(p.s2+(LL)l*x)%mod;
} else {
if(p.s1||p.s1) pushdown(o);
int mid=(p.l+p.r)>>1;
if(l<=mid) add(o<<1,l,r,x);
if(r>mid) add(o<<1^1,l,r,x);
update(o);
}
}
int query(int o,int l,int r) {
Node p=nd[o];
if(l<=p.l&&p.r<=r) return p.sum;
else {
if(p.s1||p.s2) pushdown(o);
int mid=(p.l+p.r)>>1,ans=0;
if(l<=mid) ans=((LL)ans+query(o<<1,l,r))%mod;
if(r>mid) ans=((LL)ans+query(o<<1^1,l,r))%mod;
return ans;
}
}
}t;
int n,m;
int main() {
freopen("segment.in","r",stdin);
freopen("segment.out","w",stdout);
scanf("%d%d",&n,&m);
t.build(t.root,1,n);
int opt,l,r,x;
for(int i=1;i<=m;i++) {
scanf("%d%d%d",&opt,&l,&r);
if(opt) scanf("%d",&x),t.add(t.root,l,r,x);
else printf("%d\n",t.query(t.root,l,r));
}
return 0;
}
公式真的好难打啊。