想必你一定会用线段树维护等差数列吧?让我们来看看它的升级版。
请你维护一个长度为5×10^5的数组,一开始数组中每个元素都为0,要求支持以下两个操作:
1、区间[l,r]加自然数的平方数组,即al+=1,al+1+=4,al+2+=9,al+3+=16...ar+=(r−l+1)∗(r−l+1)
2、区间[l,r]查询区间和mod 10^9+7
输入描述:
第一行输入n,m,(1≤n,m≤5×105)n,m。 接下来m行,对于每行,先读入一个整数q。 当q的值为1时,还需读入两个整l,r,(1≤l≤r≤n)表示需要对区间[l,r]进行操作,让第一个元素加1,第二个元素加4,第三个元素加9以此类推。 当q的值为2时,还需读入两个整数l,r(1≤l≤r≤n)表示查询l到r的元素和
输出描述:
对于每一个q=2,输出一行一个非负整数,表示l到r的区间和mod 10^9+7。
示例1
输入
4 4 2 1 4 1 1 4 1 3 4 2 1 4
输出
0 35
示例2
输入
10 6 1 1 6 1 8 9 1 3 6 2 1 10 1 1 10 2 1 10
输出
126 511
[l,r]添加平方数列
对于任意位置x属于[l,r]
增加的值应当是( x − ( l − 1 ) ) ^2
展开 :x^2 - 2(l-1)x +(l-1)^2
维护6个系数,分开来求
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=5e5+10;
const int mod=1e9+7;
int n,m;
struct node
{
int sum0,sum1,sum2,lz0,lz1,lz2;
} t[N*4];
void pushdown(int i,int l,int r)
{
if(t[i].lz0)
{
int k=t[i].lz0;
int mid=(l+r)>>1;
t[i<<1].sum0+=k*(mid-l+1)%mod;
t[i<<1].sum0%=mod;
t[i<<1|1].sum0+=k*(r-mid)%mod;
t[i<<1|1].sum0%=mod;
t[i<<1].lz0+=k;
t[i<<1].lz0%=mod;
t[i<<1|1].lz0+=k;
t[i<<1|1].lz0%=mod;
t[i].lz0=0;
}
if(t[i].lz1)
{
int k=t[i].lz1;
int mid=(l+r)>>1;
t[i<<1].sum1+=k*((mid+l)*(mid-l+1)/2%mod)%mod;
t[i<<1].sum1%=mod;
t[i<<1|1].sum1+=k*((r+mid+1)*(r-mid)/2%mod)%mod;
t[i<<1|1].sum1%=mod;
t[i<<1].lz1+=k;
t[i<<1].lz1%=mod;
t[i<<1|1].lz1+=k;
t[i<<1|1].lz1%=mod;
t[i].lz1=0;
}
if(t[i].lz2)
{
int k=t[i].lz2;
int mid=(l+r)>>1;
t[i<<1].sum2+=k*((mid*(mid+1)/2*(2*mid+1)/3%mod)-((l-1)*((l-1)+1)/2*(2*(l-1)+1)/3%mod)+mod)%mod%mod;
t[i<<1].sum2%=mod;
t[i<<1|1].sum2+=k*((r*(r+1)/2*(2*r+1)/3%mod)-((mid+1-1)*((mid+1-1)+1)/2*(2*(mid+1-1)+1)/3%mod)+mod)%mod%mod;
t[i<<1|1].sum2%=mod;
t[i<<1].lz2+=k;
t[i<<1].lz2%=mod;
t[i<<1|1].lz2+=k;
t[i<<1|1].lz2%=mod;
t[i].lz2=0;
}
}
void build(int i,int l,int r)
{
t[i].lz0=t[i].lz1=t[i].lz2=0;
if(l==r)
{
t[i].sum0=t[i].sum1=t[i].sum2=0;
return ;
}
int mid=(l+r)>>1;
build(i<<1,l,mid);
build(i<<1|1,mid+1,r);
//pushup(rt);
}
void update0(int rt,int l,int r,int L,int R,int k)
{
if(L<=l&&r<=R)
{
t[rt].sum0+=k*(r-l+1)%mod;
t[rt].sum0%=mod;
t[rt].lz0+=k;
t[rt].lz0%=mod;
return ;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
if(L<=mid) update0(rt<<1,l,mid,L,R,k);
if(R>mid)update0(rt<<1|1,mid+1,r,L,R,k);
t[rt].sum0=(t[rt<<1].sum0+t[rt<<1|1].sum0)%mod;
return;
}
void update1(int rt,int l,int r,int L,int R,int k)
{
if(L<=l&&r<=R)
{
t[rt].sum1+=k*((r+l)*(r-l+1)/2%mod)%mod;
t[rt].sum1%=mod;
t[rt].lz1+=k;
t[rt].lz1%=mod;
return;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
if(L<=mid) update1(rt<<1,l,mid,L,R,k);
if(R>mid)update1(rt<<1|1,mid+1,r,L,R,k);
t[rt].sum1=(t[rt<<1].sum1+t[rt<<1|1].sum1)%mod;
return;
}
void update2(int rt,int l,int r,int L,int R,int k)
{
if(L<=l&&r<=R)
{
t[rt].sum2+=k*((r*(r+1)/2*(2*r+1)/3%mod)-((l-1)*((l-1)+1)/2*(2*(l-1)+1)/3%mod)+mod)%mod%mod;
t[rt].sum2%=mod;
t[rt].lz2+=k;
t[rt].lz2%=mod;
return;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
if(L<=mid) update2(rt<<1,l,mid,L,R,k);
if(R>mid)update2(rt<<1|1,mid+1,r,L,R,k);
t[rt].sum2=(t[rt<<1].sum2+t[rt<<1|1].sum2)%mod;
return;
}
int query0(int rt,int l,int r,int L,int R)
{
if(L<=l&&R>=r)
{
return t[rt].sum0;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
int ans=0;
if(mid >= R) ans=ans+query0(rt<<1, l, mid, L, R),ans%=mod;
else if(mid < L) ans=ans+query0(rt<<1|1, mid + 1, r, L, R),ans%=mod;
else
{
ans=ans+query0(rt<<1, l, mid, L, mid)+query0(rt<<1|1, mid + 1, r, mid + 1, R),ans%=mod;
}
return ans;
}
int query1(int rt,int l,int r,int L,int R)
{
if(L<=l&&R>=r)
{
return t[rt].sum1;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
int ans=0;
if(mid >= R) ans=ans+query1(rt<<1, l, mid, L, R),ans%=mod;
else if(mid < L) ans=ans+query1(rt<<1|1, mid + 1, r, L, R),ans%=mod;
else
{
ans=ans+query1(rt<<1, l, mid, L, mid)+query1(rt<<1|1, mid + 1, r, mid + 1, R),ans%=mod;
}
return ans;
}
int query2(int rt,int l,int r,int L,int R)
{
if(L<=l&&R>=r)
{
return t[rt].sum2;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
int ans=0;
if(mid >= R) ans=ans+query2(rt<<1, l, mid, L, R),ans%=mod;
else if(mid < L) ans=ans+query2(rt<<1|1, mid + 1, r, L, R),ans%=mod;
else
{
ans=ans+query2(rt<<1, l, mid, L, mid)+query2(rt<<1|1, mid + 1, r, mid + 1, R),ans%=mod;
}
return ans;
}
signed main()
{
cin>>n>>m;
build(1,1,n);
for(int i=1; i<=m; i++)
{
int op;
cin>>op;
if(op==1)
{
int l,r;
cin>>l>>r;
update0(1,1,n,l,r,(l-1)*(l-1)%mod);
update1(1,1,n,l,r,(-2*(l-1)%mod+mod)%mod);
update2(1,1,n,l,r,1);
}
else
{
int l,r;
cin>>l>>r;
cout<<(query0(1,1,n,l,r)%mod+query1(1,1,n,l,r)%mod+query2(1,1,n,l,r)%mod)%mod<<"\n";
}
}
return 0;
}