题目
给出一个长为n(n<=1e6)的排列p[],
规定一个区间 [l,r] (l<=r) 是 fair 的,当且仅当区间中最小值等于 l 且最大值等于 r
求 fair 区间的个数。
思路来源
https://ac.nowcoder.com/acm/contest/view-submission?submissionId=44985124
题解1
单调栈,预处理出每个位置作为最小值/最大值
所到的最左/最右的地方,分别记为mnl,mnr,mxl,mxr
考虑[4,1,2,3],区间重排之后最小值是1,最大值是4,所以(1,4)这个值对答案有贡献,
换言之,[1,4]中1是最小值,[1,4]中4是最大值,
即,[1,4]包含于[mnl[1所在位置],mnr[1所在位置]],[1,4]包含于[mxl[4所在位置],mxr[4所在位置]]
也即,
1包含于[mnl[1所在位置],mnr[1所在位置]]①,
4包含于[mnl[1所在位置],mnr[1所在位置]]②
1包含于[mxl[4所在位置],mxr[4所在位置]]③,
4包含于[mxl[4所在位置],mxr[4所在位置]]④
首先,检查a[i]是不是在[mnl[i],mnr[i]]里,这样的值才能成为区间左端点,
处理成最小值[mnl[i],mnr[i]]线段,最大值线段同理
然后考虑合法的最小值线段,扫描线思想,左端点+1,右端点-1,
在[mnl[i],mnr[i]]值域内的右端点x,可以以ai为左端点
增序扫描值域,
一边动态维护最小值线段,把最小值ai单点插到BIT上,
一边用最大值线段去查答案,
假设当前枚举到的最大值是x,说明x已经落在之前插入的最小值线段里了,
如果此时最小值线段里的点ai,也落在x的最大值线段里,说明(ai,x)是一个合法的点对,
此时,用BIT区间求和[mxl[x],mxr[x]]内合法的值,加到答案里即可
①-④四条需要同时满足,才有对答案的贡献,
①、④不满足就没有对应的最小值/最大值线段,
③不满足时,4到来时,1的最小值线段已经被删了
②不满足时,4统计答案时,就不会把1计入答案
代码1
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
typedef long long ll;
int n,a[N],stk[N],mnl[N],mnr[N],mxl[N],mxr[N],c;
vector<int>l[N],r[N],w[N];
ll ans;
struct BIT{
int n,tr[N];
void init(int _n){
n=_n;
memset(tr,0,sizeof tr);
}
void add(int x,int v){
for(int i=x;i<=n;i+=i&-i)
tr[i]+=v;
}
int sum(int x){
int ans=0;
for(int i=x;i;i-=i&-i)
ans+=tr[i];
return ans;
}
}tr;
int main(){
scanf("%d",&n);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
}
a[0]=0;
stk[c=1]=0;
for(int i=1;i<=n;++i){
while(c && a[stk[c]]>a[i]){
c--;
}
mnl[i]=stk[c]+1;
stk[++c]=i;
}
a[n+1]=0;
stk[c=1]=n+1;
for(int i=n;i>=1;--i){
while(c && a[stk[c]]>a[i]){
c--;
}
mnr[i]=stk[c]-1;
stk[++c]=i;
}
a[0]=n+1;
stk[c=1]=0;
for(int i=1;i<=n;++i){
while(c && a[stk[c]]<a[i]){
c--;
}
mxl[i]=stk[c]+1;
stk[++c]=i;
}
a[n+1]=n+1;
stk[c=1]=n+1;
for(int i=n;i>=1;--i){
while(c && a[stk[c]]<a[i]){
c--;
}
mxr[i]=stk[c]-1;
stk[++c]=i;
}
for(int i=1;i<=n;++i){
if(mnl[i]<=a[i] && a[i]<=mnr[i]){
l[mnl[i]].push_back(a[i]);
r[mnr[i]].push_back(a[i]);
}
if(mxl[i]<=a[i] && a[i]<=mxr[i]){
w[a[i]].push_back(i);
}
}
tr.init(n);
for(int i=1;i<=n;++i){
for(auto &x:l[i])tr.add(x,1);
for(auto &x:w[i])ans+=tr.sum(mxr[x])-tr.sum(mxl[x]-1);
for(auto &x:r[i])tr.add(x,-1);
}
printf("%lld\n",ans);
return 0;
}
题解2
考虑分治,mn和mx用于维护[l,mid]的后缀最小值和后缀最大值,[mid+1,r]的前缀最小值和前缀最大值
根据最小值、最大值的位置,跨区间的答案分为四种,左左、右右、左右、右左
同侧的情况,
①左左,在左半区间枚举区间最小值mn[i],判断[mn[i],mx[i]]是不是一个合法的跨区间
②右右,在右半区间枚举区间最大值mx[i],判断[mn[i],mx[i]]是不是一个合法的跨区间
异侧的情况,枚举最大值,
③左右,最大值在右,对于[ql,qr)里的x,[x,mx[i]]都合法,有点单调双指针的意思,
先向左延展ql,统计mx[ql]<mx[i]时,mn[ql]==ql的答案,但此时可能mn[i]<mn[ql]使答案不合法,
再向左延展qr,将这些不合法的答案减掉
④右左,最大值在左,区间只能是[i,mx[i]],分别判断左半段和右半段的值域是否合法
代码2
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
typedef long long ll;
int n,a[N],mn[N],mx[N];
ll solve(int l,int r){
if(l==r){
//printf("l:%d r:%d ans:%lld\n",l,r,a[l]==l);
return a[l]==l;
}
int mid=(l+r)/2;
ll ans=solve(l,mid)+solve(mid+1,r);
mn[mid]=mx[mid]=a[mid];
mn[mid+1]=mx[mid+1]=a[mid+1];
for(int i=mid-1;i>=l;--i){
mn[i]=min(mn[i+1],a[i]);
mx[i]=max(mx[i+1],a[i]);
}
for(int i=mid+2;i<=r;++i){
mn[i]=min(mn[i-1],a[i]);
mx[i]=max(mx[i-1],a[i]);
}
//printf("l:%d r:%d ans:%lld\n",l,r,ans);
//最大值、最小值同侧
//左侧 枚举最小值端点
for(int i=mid;i>=l;--i){
if(mn[i]!=i)continue;
//[i,mx[i]]
if(mid+1<=mx[i] && mx[i]<=r){
int id=mx[i];
//另一端值域是[i,mx[i]]的子区间
if(i<=mn[id] && mx[id]<=id){
ans++;
}
}
}
//printf("l:%d r:%d ans:%lld\n",l,r,ans);
//右侧 枚举最大值端点
for(int i=mid+1;i<=r;++i){
if(mx[i]!=i)continue;
//[mn[i],i]
if(l<=mn[i] && mn[i]<=mid){
int id=mn[i];
if(id<=mn[id] && mx[id]<=i){
ans++;
}
}
}
//printf("l:%d r:%d ans:%lld\n",l,r,ans);
//最大值、最小值异侧 枚举最大值
//最大值位于右侧
int ql=mid,qr=mid;
ll res=0;
for(int i=mid+1;i<=r;++i){
if(mx[i]!=i)continue;
//[ql,i] 但可能计入一些mn[i]<ql的非法答案
while(ql>=l && mx[ql]<=mx[i]){
if(mn[ql]==ql)res++;
ql--;
}
//减掉非法答案 [ql,qr)内的x是mn[x]<=mn[i]的合法答案
while(qr>ql && mn[qr]>mn[i]){
if(mn[qr]==qr)res--;
qr--;
}
ans+=res;
}
//printf("l:%d r:%d ans:%lld\n",l,r,ans);
//最大值位于左侧
for(int i=mid;i>=l;--i){
//[i,mx[i]]
if(mid+1<=mx[i] && mx[i]<=r){
int id=mx[i];
if(i<mn[i] && mn[id]==i && mx[id]<id){
ans++;
}
}
}
//printf("l:%d r:%d ans:%lld\n",l,r,ans);
return ans;
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
}
printf("%lld\n",solve(1,n));
return 0;
}