题目链接
题目解法
考虑在值域上的问题:有多少段区间,对应在排列上不超过
2
2
2 段
肯定需要枚举一个端点,另一个快速算出,考虑枚举值域区间右端点
r
r
r,计算
l
l
l
可以发现,新增一个数对应在排列上的区间有 3 种不同的情况
- 新增一个段
- 合并 2 个段
- 和左边或右边相连,段数不变
这三种操作对应的值域区间范围不难得出,然后线段树维护即可
时间复杂度
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
这道题的一个启发是:添加一个数,可以快速维护出以这个数为右端点的所有区间对应在一个数列上的连通段的个数
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=300100;
int n,col[N],pos[N];
struct Node{
int mn1,mn2,sum1,sum2,tag;
}seg[N<<2];
struct SegmentTree{
void build(int l,int r,int x){
seg[x].tag=0,seg[x].mn1=0,seg[x].sum1=r-l+1,seg[x].mn2=1e9,seg[x].sum2=0;
if(l==r) return;
int mid=(l+r)>>1;
build(l,mid,x<<1),build(mid+1,r,x<<1^1);
}
Node merge(Node lc,Node rc){
Node ret;ret.sum1=ret.sum2=ret.tag=0;
ret.mn1=min(lc.mn1,rc.mn1);
ret.mn2=min(lc.mn2,rc.mn2);
if(lc.mn1>ret.mn1) ret.mn2=min(ret.mn2,lc.mn1);
if(rc.mn1>ret.mn1) ret.mn2=min(ret.mn2,rc.mn1);
if(lc.mn1==ret.mn1) ret.sum1+=lc.sum1;
if(rc.mn1==ret.mn1) ret.sum1+=rc.sum1;
if(lc.mn1==ret.mn2) ret.sum2+=lc.sum1;
if(rc.mn1==ret.mn2) ret.sum2+=rc.sum1;
if(lc.mn2==ret.mn2) ret.sum2+=lc.sum2;
if(rc.mn2==ret.mn2) ret.sum2+=rc.sum2;
return ret;
}
void pushdown(int x){
int D=seg[x].tag;
// cout<<"Delta : "<<D<<'\n';
seg[x<<1].mn1+=D,seg[x<<1].mn2+=D,seg[x<<1].tag+=D;
seg[x<<1^1].mn1+=D,seg[x<<1^1].mn2+=D,seg[x<<1^1].tag+=D;
seg[x].tag=0;
}
void modify(int l,int r,int x,int L,int R,int v){
if(L<=l&&r<=R){
// cout<<l<<' '<<r<<' '<<seg[x].mn1<<' '<<v<<'\n';
seg[x].mn1+=v,seg[x].mn2+=v,seg[x].tag+=v;
// cout<<l<<' '<<r<<' '<<seg[x].mn1<<' '<<v<<'\n';
return;
}
pushdown(x);
int mid=(l+r)>>1;
if(mid>=L) modify(l,mid,x<<1,L,R,v);
if(mid<R) modify(mid+1,r,x<<1^1,L,R,v);
seg[x]=merge(seg[x<<1],seg[x<<1^1]);
// cout<<l<<' '<<r<<' '<<seg[x<<1].mn1<<' '<<seg[x<<1^1].mn1<<' '<<seg[x].mn1<<'\n';
}
Node query(int l,int r,int x,int L,int R){
if(L<=l&&r<=R) return seg[x];
pushdown(x);
int mid=(l+r)>>1;
if(mid>=L&&mid<R) return merge(query(l,mid,x<<1,L,R),query(mid+1,r,x<<1^1,L,R));
if(mid>=L) return query(l,mid,x<<1,L,R);
return query(mid+1,r,x<<1^1,L,R);
}
}sg;
inline int read(){
int FF=0,RR=1;
char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
return FF*RR;
}
int main(){
n=read();
for(int i=1;i<=n;i++) col[i]=read(),pos[col[i]]=i;
sg.build(1,n,1);
LL ans=0;
for(int i=1;i<=n;i++){
int L=col[pos[i]-1],R=col[pos[i]+1];
//add a block -> L and R not added
int bound=max(L>i?1:L+1,R>i?1:R+1);
//opt : [bound,i] +1
sg.modify(1,n,1,bound,i,1);
//merge two blocks
bound=min(L,R);
if(max(L,R)<i&&bound>0){
//opt : [1,bound] -1
sg.modify(1,n,1,1,bound,-1);
}
//merge with L or R
//[min(L,R)+1,max(L,R)] do not change
if(i>1){
Node t=sg.query(1,n,1,1,i-1);
if(t.mn1<=2) ans+=t.sum1;
if(t.mn2<=2) ans+=t.sum2;
}
}
printf("%lld",ans);
return 0;
}
/*
10
5 6 1 3 4 9 10 2 8 7
*/