题面
题意
给出一个有n个数的排列,现在请你选择两个不重叠的区间,使这两个区间的数的集合可以组成一个公差为1的等差数列,问有几个这样的集合。
做法
我们可以用f[l,r]表示[l,r]中的所有数在给出排列中最少分成几段,这样问题就转化为了求满足
f
[
l
,
r
]
<
=
2
,
l
=
̸
r
f[l,r]<=2,l =\not r
f[l,r]<=2,l≠r的区间个数。
可以考虑枚举区间的右端点,用线段树维护
f
[
1
,
r
]
f[1,r]
f[1,r]-
f
[
r
,
r
]
f[r,r]
f[r,r]的值,这样当
r
+
+
r++
r++后,假设数r这一个数一个区间,那么就相当于对线段树上的区间
[
1
,
r
]
+
1
[1,r]+1
[1,r]+1,如果r在排列中左边的数t比它小,则可以发现区间
f
[
1
,
r
]
f[1,r]
f[1,r]-
f
[
t
,
r
]
f[t,r]
f[t,r]都
−
1
-1
−1,而右边的数也同理,这样只要对线段树维护区间加,区间查询1,2的个数和(就是维护最小值,最小值的个数,次小值,次小值的个数)操作即可。
代码
#include<bits/stdc++.h>
#define ll long long
#define P pair<ll,ll>
#define mp make_pair
#define fi first
#define se second
#define INF 0x3f3f3f3f
#define N 300100
using namespace std;
ll n,ans,tt,num[N],pos[N];
P gg[5];
struct Node
{
ll ls,rs,dw;
P mn,m2;
void add(ll u){dw+=u,mn.fi+=u,m2.fi+=u;}
}node[N<<1];
inline void up(ll now)
{
ll i,L=node[now].ls,R=node[now].rs;
gg[0]=node[L].mn,gg[1]=node[L].m2;
gg[2]=node[R].mn,gg[3]=node[R].m2;
sort(gg,gg+4);
node[now].mn=gg[0];
for(i=1;i<4;i++)
{
if(gg[i].fi!=gg[0].fi) break;
node[now].mn.se+=gg[i].se;
}
if(i<4)
{
node[now].m2=gg[i];
for(i++;i<4;i++)
{
if(gg[i].fi!=node[now].m2.fi) break;
node[now].m2.se+=gg[i].se;
}
}
else node[now].m2=mp(INF,0);
}
inline void down(ll now)
{
ll L=node[now].ls,R=node[now].rs;
if(node[now].dw)
{
node[L].add(node[now].dw);
node[R].add(node[now].dw);
node[now].dw=0;
}
}
void build(ll now,ll l,ll r)
{
if(l==r)
{
node[now].mn=mp(0,1);
node[now].m2=mp(INF,0);
return;
}
ll mid=((l+r)>>1);
node[now].ls=++tt;
build(tt,l,mid);
node[now].rs=++tt;
build(tt,mid+1,r);
up(now);
}
void add(ll now,ll l,ll r,ll u,ll v,ll w)
{
if(u<=l&&r<=v)
{
node[now].add(w);
return;
}
down(now);
ll mid=((l+r)>>1);
if(u<=mid) add(node[now].ls,l,mid,u,v,w);
if(mid<v) add(node[now].rs,mid+1,r,u,v,w);
up(now);
}
ll ask(ll now,ll l,ll r,ll u,ll v)
{
if(u<=l&&r<=v)
{
ll res=0;
if(node[now].mn.fi<=2) res+=node[now].mn.se;
if(node[now].m2.fi<=2) res+=node[now].m2.se;
return res;
}
down(now);
ll res=0,mid=((l+r)>>1);
if(u<=mid) res+=ask(node[now].ls,l,mid,u,v);
if(mid<v) res+=ask(node[now].rs,mid+1,r,u,v);
return res;
}
int main()
{
ll i,j;
cin>>n;
for(i=1;i<=n;i++)
{
scanf("%lld",&num[i]);
pos[num[i]]=i;
}
build(tt=1,1,n);
for(i=1;i<=n;i++)
{
add(1,1,n,1,i,1);
if(pos[i]>1&&num[pos[i]-1]<i) add(1,1,n,1,num[pos[i]-1],-1);
if(pos[i]<n&&num[pos[i]+1]<i) add(1,1,n,1,num[pos[i]+1],-1);
if(i>1) ans+=ask(1,1,n,1,i-1);
}
cout<<ans;
}