题目大意
有两个长度为
n
n
n的序列
a
a
a和
b
b
b,有
q
q
q次询问,每次询问给出
l
,
r
l,r
l,r,求
∑
i
=
l
r
∑
j
=
i
+
1
r
(
max
k
=
i
j
a
k
)
×
(
max
l
=
i
j
b
l
)
\sum\limits_{i=l}^r\sum\limits_{j=i+1}^r(\max\limits_{k=i}^ja_k)\times(\max\limits_{l=i}^jb_l)
i=l∑rj=i+1∑r(k=imaxjak)×(l=imaxjbl)
题解
我们可以离线做,对所有询问按 r r r从小到大排序。假设当前考虑到 r r r,令 x i = max j = i r a j x_i=\max\limits_{j=i}^ra_j xi=j=imaxraj, y i = max j = i r b j y_i=\max\limits_{j=i}^rb_j yi=j=imaxrbj,我们可以用线段树维护 v i = ∑ j = i r x j y j v_i=\sum\limits_{j=i}^rx_jy_j vi=j=i∑rxjyj,那么对于每一次询问,答案就是 ∑ i = l r v i \sum\limits_{i=l}^rv_i i=l∑rvi。
可以先用单调栈处理出 a , b a,b a,b数组中每个元素作为最大值向左能覆盖的最大位置,那么对于每一个询问的 r r r,我们将当前右端点不断移动直到到达 r r r点,每一次移动都要对线段树进行以下三个操作:
- 区间更新这一段的 x i x_i xi为 a r a_r ar
- 区间更新这一段的 y i y_i yi为 b r b_r br
- 让下标在 1 1 1到 r r r之间的每一个 v i v_i vi都加上 x i × y i x_i\times y_i xi×yi。
线段树一共要维护四个信息, s s s表示 f i f_i fi的区间和, x y xy xy表示 x i × y i x_i\times y_i xi×yi的区间和, x x x表示 x i x_i xi的区间和, y y y表示 y i y_i yi的区间和。
有六个懒标记, c x , c y cx,cy cx,cy是覆盖标记, l y x y , l y x , l y y , l y v ly_{xy},ly_x,ly_y,ly_v lyxy,lyx,lyy,lyv分别是 ∑ x i y i , ∑ x , ∑ y , ∑ 1 \sum x_iy_i,\sum x,\sum y,\sum 1 ∑xiyi,∑x,∑y,∑1(即区间长度)的增量,也就是增加的倍数。
在下传懒标记时,按照是否有被覆盖,要做一些处理:
- 如果 x , y x,y x,y都被覆盖,则 l y x y ∑ x i y i = l y x y × c x × c y × l e n ly_{xy}\sum x_iy_i=ly_{xy}\times c_x\times c_y\times len lyxy∑xiyi=lyxy×cx×cy×len,将 l y x y × c x × c y ly_{xy}\times c_x\times c_y lyxy×cx×cy加到 l y c ly_c lyc中
- 如果 x x x被覆盖,则 l y x y ∑ x i y i = l y x y × c x × ∑ y i ly_{xy}\sum x_iy_i=ly_{xy}\times c_x\times \sum y_i lyxy∑xiyi=lyxy×cx×∑yi, l y x ∑ x i = l y x × c x × l e n ly_x\sum x_i=ly_x\times c_x\times len lyx∑xi=lyx×cx×len,将 l y x y × c x ly_{xy}\times c_x lyxy×cx加到 l y y ly_y lyy中,将 l y x × c x ly_x\times c_x lyx×cx加到 l y c ly_c lyc中
- 如果 y y y被覆盖,与 x x x被覆盖类似
- 如果都没有被覆盖,则各自相加即可
更新完加操作后,再更新覆盖操作。
对于区间信息的更新,首先 s + = l y x y × ∑ x i y i + l y x × ∑ x i + l y y × ∑ y i + l y c × ∑ 1 s+=ly_{xy}\times \sum x_iy_i+ly_x\times \sum x_i+ly_y\times \sum y_i+ly_c\times \sum 1 s+=lyxy×∑xiyi+lyx×∑xi+lyy×∑yi+lyc×∑1,其余部分要根据被覆盖的情况来更新。
对于取模的问题,用unsigned long long自然溢出即可解决。
可以看代码帮助理解。
时间复杂度为 O ( n log n + q log n ) O(n\log n+q\log n) O(nlogn+qlogn)。
code
#include<bits/stdc++.h>
#define lc k<<1
#define rc k<<1|1
using namespace std;
typedef unsigned long long ULL;
const int N=250005;
int tt,n,q,q1[N],q2[N],l1[N],l2[N];
ULL a[N],b[N],ans[N];
struct gt{
int l,r,id;
}w[N];
struct node{
ULL s,xy,x,y;
}tr[N*4];
struct tag{
ULL cx,cy,xy,x,y,c;
}ly[N*4];
bool cmp(gt ax,gt bx){
return ax.r<bx.r;
}
void up(int k){
tr[k]=(node){tr[lc].s+tr[rc].s,
tr[lc].xy+tr[rc].xy,
tr[lc].x+tr[rc].x,
tr[lc].y+tr[rc].y};
}
void gx(int k,int len,tag t){
tag &v1=ly[k];
if(v1.cx&&v1.cy){
v1.c+=t.xy*v1.cx*v1.cy+t.x*v1.cx+t.y*v1.cy+t.c;
}
else if(v1.cx){
v1.c+=t.x*v1.cx+t.c;
v1.y+=t.xy*v1.cx+t.y;
}
else if(v1.cy){
v1.c+=t.y*v1.cy+t.c;
v1.x+=t.xy*v1.cy+t.x;
}
else{
v1.xy+=t.xy;
v1.x+=t.x;
v1.y+=t.y;
v1.c+=t.c;
}
if(t.cx) v1.cx=t.cx;
if(t.cy) v1.cy=t.cy;
node &v2=tr[k];
v2.s+=t.xy*v2.xy+t.x*v2.x+t.y*v2.y+t.c*len;
if(t.cx&&t.cy){
v2.xy=t.cx*t.cy*len;
v2.x=t.cx*len;
v2.y=t.cy*len;
}
else if(t.cx){
v2.xy=t.cx*v2.y;
v2.x=t.cx*len;
}
else if(t.cy){
v2.xy=t.cy*v2.x;
v2.y=t.cy*len;
}
}
void down(int k,int l,int r){
if(ly[k].cx||ly[k].cy||ly[k].xy||ly[k].x||ly[k].y||ly[k].c){
int mid=l+r>>1;
gx(lc,mid-l+1,ly[k]);
gx(rc,r-mid,ly[k]);
ly[k]=(tag){0,0,0,0,0,0};
}
}
void ch(int k,int l,int r,int x,int y,tag t){
if(l>=x&&r<=y){
gx(k,r-l+1,t);
return;
}
down(k,l,r);
int mid=l+r>>1;
if(x<=mid) ch(lc,l,mid,x,y,t);
if(y>mid) ch(rc,mid+1,r,x,y,t);
up(k);
}
ULL find(int k,int l,int r,int x,int y){
if(l>=x&&r<=y) return tr[k].s;
down(k,l,r);
int mid=l+r>>1;
ULL re=0;
if(x<=mid) re+=find(lc,l,mid,x,y);
if(y>mid) re+=find(rc,mid+1,r,x,y);
return re;
}
int main()
{
scanf("%d%d",&tt,&n);
a[0]=b[0]=1e9;
q1[++q1[0]]=q2[++q2[0]]=0;
for(int i=1;i<=n;i++){
scanf("%llu",&a[i]);
while(q1[0]&&a[i]>a[q1[q1[0]]]) --q1[0];
l1[i]=q1[q1[0]]+1;q1[++q1[0]]=i;
}
for(int i=1;i<=n;i++){
scanf("%llu",&b[i]);
while(q2[0]&&b[i]>b[q2[q2[0]]]) --q2[0];
l2[i]=q2[q2[0]]+1;q2[++q2[0]]=i;
}
scanf("%d",&q);
for(int i=1;i<=q;i++){
scanf("%d%d",&w[i].l,&w[i].r);w[i].id=i;
}
sort(w+1,w+q+1,cmp);
for(int i=1,r=0;i<=q;i++){
while(r<w[i].r){
++r;
ch(1,1,n,l1[r],r,(tag){a[r],0,0,0,0,0});
ch(1,1,n,l2[r],r,(tag){0,b[r],0,0,0,0});
ch(1,1,n,1,r,(tag){0,0,1,0,0,0});
}
ans[w[i].id]=find(1,1,n,w[i].l,w[i].r);
}
for(int i=1;i<=q;i++){
printf("%llu\n",ans[i]);
}
return 0;
}