题意:
你有一个长度为
n
n
n的排列,对于一对下标
i
,
j
(
i
<
j
)
i,j(i<j)
i,j(i<j),如果它们之间没有比
a
i
a_i
ai和
a
j
a_j
aj大的数字,那么贡献为
p
1
p_1
p1,如果存在一个位置
k
k
k,满足
a
i
<
a
k
<
a
j
a_i<a_k<a_j
ai<ak<aj或
a
j
<
a
k
<
a
i
a_j<a_k<a_i
aj<ak<ai,则贡献为
p
2
p_2
p2,其他情况没有贡献。有
m
m
m次询问,每次询问给你一个区间
[
l
,
r
]
[l,r]
[l,r],问你区间内所有下标二元组
i
,
j
(
l
<
=
i
<
j
<
=
r
)
i,j(l<=i<j<=r)
i,j(l<=i<j<=r)的贡献之和。
n
,
m
<
=
2
e
5
n,m<=2e5
n,m<=2e5,保证答案不爆long long
题解:
高中时就听说过这个题不错,有人推荐过,但是我当时没去做。然后现在去做的时候翻了翻那个神仙的博客,发现我和他这个题的做法还不太一样QAQ
先说一下我的做法吧。
对于这种区间询问的问题,首先考虑线段树直接维护,看需要维护哪些信息才能合并和维护答案。我们发现想要合并答案,似乎需要对于每个区间,维护两个单调栈,分别是左侧开始的和右侧开始的栈,每次区间合并我们要把单调栈合并。那么我们就可以知道这个区间信息没法直接维护和合并了。
对于这种区间询问的问题,我们的一个常见想法是把询问全部离线,然后对于每个位置,回答所有以这个位置为右端点的询问。这样我们一边从左向右枚举右端点一边更新答案和回答询问。
首先先考虑第一类贡献。对于当前右端点,它作为二元组中的 j j j时,所有可行的 i i i满足 i , j i,j i,j之间没有更大的数。这个东西我们维护一个单调栈,越靠近栈顶的元素下标越大,值越小。于是我们在加入一个数时所有弹栈被弹出的位置都符合要求,如果弹玩栈不为空,那么栈顶也符合要求。我们对于一个区间的信息,我们用线段树维护就行,每次找到符合的位置就单点加,对于询问我们直接区间和就行,因为我们是按照右端点顺序回答询问,所以不用主席树前缀相减什么的也能正确维护。
再说一下第二类贡献。这个比第一类麻烦一些。我们还是可以发现这个和单调栈中的元素有关。每次加入元素时,如果当前元素能弹栈,那么栈顶第一个元素与栈顶第二个元素之间的位置作为左端点,当前位置作为右端点一定是可行的。但是这样只考虑了当前位置是两个端点中较大的答案。其实较小的也不难做,我们只需要反过来从后向前再做一遍就行,做的时候我们把询问按照左端点从大到小排序来做。
这样我的做法就介绍完了。
再说一下神仙y_immortal的做法。他的做法我不是很确定是不是和我下面说的完全一样的思路,但是大概思路应该是差不多的。
他的做法应该是用单调栈正反扫两次,对于每个位置算出前一个和后一个比它大的位置。对于第一种情况,所有合法区间就是每个位置对应的前一个和后一个比它大的位置形成的区间。对于第二种情况,就是前面比它大的那个数做左端点,它后一个位置到后一个比它大的位置的前一个位置做右端点都是合法的;它后面比它大的第一个位置做右端点,前面比它大的第一个位置的下一个到它的前一个位置做合法的左端点。换句话说,就是它作为那个区间中的最大值,我们枚举了区间最大值,但是我们其实做的也并不完全是这样,而是枚举了每个位置做最大值,看哪个是对应的左右端点。但是我没仔细想怎么具体算,但是可能也要在算的时候正反扫两遍。
代码:
#include <bits/stdc++.h>
using namespace std;
int n,m,a[200010],sta[200010],tp,id[200010];
long long p1,p2;
struct node
{
int l,r,id;
long long ans;
}q[200010];
struct tree
{
int l,r,tag2;
long long num1,num2;
}tr[800010];
inline int cmp(node x,node y)
{
return x.r<y.r;
}
inline int cmp2(node x,node y)
{
return x.l>y.l;
}
inline int cmp3(node x,node y)
{
return x.id<y.id;
}
inline void build(int rt,int l,int r)
{
tr[rt].l=l;
tr[rt].r=r;
tr[rt].num1=0;
tr[rt].num2=0;
tr[rt].tag2=0;
if(l==r)
return;
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
}
inline void update(int rt,int x)
{
int l=tr[rt].l,r=tr[rt].r;
if(l==r)
{
tr[rt].num1++;
return;
}
int mid=(l+r)>>1;
if(x<=mid)
update(rt<<1,x);
else
update(rt<<1|1,x);
tr[rt].num1++;
}
inline void pushdown(int rt)
{
if(tr[rt].tag2)
{
tr[rt<<1].tag2+=tr[rt].tag2;
tr[rt<<1].num2+=tr[rt].tag2*(tr[rt<<1].r-tr[rt<<1].l+1);
tr[rt<<1|1].tag2+=tr[rt].tag2;
tr[rt<<1|1].num2+=tr[rt].tag2*(tr[rt<<1|1].r-tr[rt<<1|1].l+1);
tr[rt].tag2=0;
}
}
inline void update2(int rt,int le,int ri)
{
int l=tr[rt].l,r=tr[rt].r;
if(le<=l&&r<=ri)
{
tr[rt].tag2++;
tr[rt].num2+=r-l+1;
return;
}
pushdown(rt);
int mid=(l+r)>>1;
if(le<=mid)
update2(rt<<1,le,ri);
if(mid+1<=ri)
update2(rt<<1|1,le,ri);
tr[rt].num2=tr[rt<<1].num2+tr[rt<<1|1].num2;
}
inline long long query(int rt,int le,int ri)
{
int l=tr[rt].l,r=tr[rt].r;
if(le<=l&&r<=ri)
return tr[rt].num1;
int mid=(l+r)>>1;
long long res=0;
if(le<=mid)
res+=query(rt<<1,le,ri);
if(mid+1<=ri)
res+=query(rt<<1|1,le,ri);
return res;
}
inline long long query2(int rt,int le,int ri)
{
int l=tr[rt].l,r=tr[rt].r;
if(le<=l&&r<=ri)
return tr[rt].num2;
pushdown(rt);
int mid=(l+r)>>1;
long long res=0;
if(le<=mid)
res+=query2(rt<<1,le,ri);
if(mid+1<=ri)
res+=query2(rt<<1|1,le,ri);
return res;
}
int main()
{
scanf("%d%d%lld%lld",&n,&m,&p1,&p2);
for(int i=1;i<=n;++i)
scanf("%d",&a[i]);
for(int i=1;i<=m;++i)
{
scanf("%d%d",&q[i].l,&q[i].r);
q[i].id=i;
}
sort(q+1,q+m+1,cmp);
build(1,1,n);
int cur=1;
sta[0]=2e9;
id[0]=0;
for(int i=1;i<=n;++i)
{
while(a[i]>sta[tp])
{
update(1,id[tp]);
if(id[tp-1]+1<=id[tp]-1)
update2(1,id[tp-1]+1,id[tp]-1);
--tp;
}
if(tp!=0)
update(1,id[tp]);
sta[++tp]=a[i];
id[tp]=i;
while(q[cur].r<=i&&cur<=m)
{
q[cur].ans+=query(1,q[cur].l,q[cur].r)*p1;
q[cur].ans+=query2(1,q[cur].l,q[cur].r)*p2;
++cur;
}
}
build(1,1,n);
tp=0;
id[0]=n+1;
sort(q+1,q+m+1,cmp2);
cur=1;
for(int i=n;i>=1;--i)
{
while(a[i]>sta[tp])
{
if(id[tp-1]-1>=id[tp]+1)
update2(1,id[tp]+1,id[tp-1]-1);
--tp;
}
sta[++tp]=a[i];
id[tp]=i;
while(q[cur].l>=i&&cur<=m)
{
q[cur].ans+=query2(1,q[cur].l,q[cur].r)*p2;
++cur;
}
}
sort(q+1,q+m+1,cmp3);
for(int i=1;i<=m;++i)
printf("%lld\n",q[i].ans);
return 0;
}