稍复杂一点点的线段树好题。
简化题意
对于给出的一个数列,进行以下两种操作:
- 操作1:将数列的第 p p p 个数字修改为 x x x。
- 操作2:查询区间 [ l , r ] [l,r] [l,r] 中第二大的数的个数。
思路
上来第一时间想到了线段树,一开始想到的还是正解,后面有一步想错了,赛场没切(悲)。
思路是记录区间的最大值、次大值、最大值个数和次大值个数。下面稍微讲一下实现方法:
-
该区间的最大值等于其左区间的最大值和右区间最大值的较大值。
-
可能有人想,那么该区间的次大值是不是就是其左区间的最大值和右区间最大值的较小值呢?
当然不是,
我就是这里掉坑的,对于一个区间有可能其最大值和次大值都同时存在于左区间或右区间,例如左区间最大值和次大值为 { 2 , 1 } \{2,1\} {2,1},右区间最大值和次大值为 { 4 , 3 } \{4,3\} {4,3},那么这里的最大值和次大值都存在于右区间, 不要想当然导致错误。并且还有一种特殊情况,例如左区间最大值和次大值为 { 4 , 1 } \{4,1\} {4,1},右区间最大值和次大值为 { 4 , 2 } \{4,2\} {4,2},那么它们的次大值是 2 2 2,这里两个区间都存在最大值,注意最大值和次大值不能相等,需要严格次大。如果你的最大值和次大值相等,请检查你用的方法是否有误。
这里我用的方法是直接把它们全部取出来存到一个数组内,取出两个不同的数作为该区间最终的最大值和次大值,比较偷懒但是好写。可能有人会想到这里如果不存在次大值怎么办,没关系,直接将 0 0 0 作为次大值即可,同时满足输出要求并且不会影响后续操作。
-
最大值个数的存储我比较暴力。直接检查它是否与两个子区间的最大值和次大值相等,如果相等,那么将最大值的个数加上子区间最大值的个数或次大值的个数即可。
-
次大值个数的记录方法同最大值个数的一样。
这里在上传和查询时可能有点复杂,注意实现细节。
对于更改操作,直接进行线段树单点修改操作即可,实现非常简单。
整体时间复杂度为 O ( q log n ) O(q\log{n}) O(qlogn)。
代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
ll n,q,a[1001000],l,r,op;
struct node
{
ll l,r,mx,mx2,mxcnt,mx2cnt;
node(){mx=mx2=mxcnt=mx2cnt=0;}
}tr[1001000];
ll checkmx(ll rt,ll rt2)
{
ll sum=0;
if(tr[rt].mx==tr[rt2].mx) sum+=tr[rt2].mxcnt;
if(tr[rt].mx==tr[rt2].mx2) sum+=tr[rt2].mx2cnt;
return sum;
}
ll checkmx2(ll rt,ll rt2)
{
ll sum=0;
if(tr[rt].mx2==tr[rt2].mx) sum+=tr[rt2].mxcnt;
if(tr[rt].mx2==tr[rt2].mx2) sum+=tr[rt2].mx2cnt;
return sum;
}
void pushup(ll r)
{
tr[r].mxcnt=tr[r].mx2cnt=0;
tr[r].mx=max(tr[r*2].mx,tr[r*2+1].mx);
ll o[10]={0,tr[r*2].mx,tr[r*2].mx2,tr[r*2+1].mx,tr[r*2+1].mx2};
sort(o+1,o+1+4,greater<ll>());
if(o[2]!=tr[r].mx) tr[r].mx2=o[2];
else if(o[3]!=tr[r].mx) tr[r].mx2=o[3];
else if(o[4]!=tr[r].mx) tr[r].mx2=o[4];
tr[r].mxcnt+=checkmx(r,r*2),tr[r].mx2cnt+=checkmx2(r,r*2);
tr[r].mxcnt+=checkmx(r,r*2+1),tr[r].mx2cnt+=checkmx2(r,r*2+1);
}
void build(ll rt,ll l,ll r)
{
tr[rt].l=l,tr[rt].r=r;
if(l==r)
{
tr[rt].mx=a[l];
tr[rt].mxcnt=1;
return;
}
ll mid=(l+r)/2;
build(rt*2,l,mid);
build(rt*2+1,mid+1,r);
pushup(rt);
}
void update(ll rt,ll p,ll val)
{
if(tr[rt].l==tr[rt].r)
{
tr[rt].mx=val;
tr[rt].mxcnt=1;
return;
}
ll mid=(tr[rt].l+tr[rt].r)/2;
if(p<=mid) update(rt*2,p,val);
else update(rt*2+1,p,val);
pushup(rt);
}
node query(ll rt,ll l,ll r)
{
if(tr[rt].l>r||tr[rt].r<l) return node();
if(tr[rt].l>=l&&tr[rt].r<=r) return tr[rt];
ll mid=(tr[rt].l+tr[rt].r)/2;
if(l>mid) return query(rt*2+1,l,r);
if(r<=mid) return query(rt*2,l,r);
node tmp,a=query(rt*2,l,r),b=query(rt*2+1,l,r);
tmp.mx=max(a.mx,b.mx);
ll o[10]={0,a.mx,a.mx2,b.mx,b.mx2};
sort(o+1,o+1+4,greater<ll>());
if(o[2]!=tmp.mx) tmp.mx2=o[2];
else if(o[3]!=tmp.mx) tmp.mx2=o[3];
else if(o[4]!=tmp.mx) tmp.mx2=o[4];
tmp.l=a.l,tmp.r=b.r;
if(tmp.mx==a.mx) tmp.mxcnt+=a.mxcnt;
if(tmp.mx==a.mx2) tmp.mxcnt+=a.mx2cnt;
if(tmp.mx2==a.mx) tmp.mx2cnt+=a.mxcnt;
if(tmp.mx2==a.mx2) tmp.mx2cnt+=a.mx2cnt;
if(tmp.mx==b.mx) tmp.mxcnt+=b.mxcnt;
if(tmp.mx==b.mx2) tmp.mxcnt+=b.mx2cnt;
if(tmp.mx2==b.mx) tmp.mx2cnt+=b.mxcnt;
if(tmp.mx2==b.mx2) tmp.mx2cnt+=b.mx2cnt;
return tmp;
}
int main()
{
cin>>n>>q;
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
}
build(1,1,n);
while(q--)
{
scanf("%lld%lld%lld",&op,&l,&r);
if(op==1) update(1,l,r);
else cout<<query(1,l,r).mx2cnt<<endl;
}
return 0;
}