题目
题解
题意
给一数组
a
a
a,记录下每个连续子数组的最大值
有
Q
Q
Q次询问,每个询问将问:记录的数中有多少个数是小于/大于/等于给出的
K
K
K的
分析
既然要求子数组最大值
那么对于一个位置
i
(
1
≤
i
≤
n
)
i(1≤i≤n)
i(1≤i≤n),可以先暴力往左右两边找,找到最远处使得这个区间内的最大值是
a
[
i
]
a[i]
a[i]
可以用线段树记录最大值,然后区间查找
但是这个时间复杂度是
O
(
n
2
l
o
g
(
n
)
)
O(n^2log(n))
O(n2log(n))的
考虑优化
对于最远处,可以用二分来找
优化成
O
(
n
l
o
g
(
n
2
)
)
O(nlog(n^2))
O(nlog(n2))
那么
i
i
i对答案的贡献就是(
i
i
i-最左端的下标+1)*(最右端的下标-
i
i
i+1),记录下来
(
n
u
m
[
i
]
)
(num[i])
(num[i])
将
a
[
i
]
a[i]
a[i]排序(同时更改
n
u
m
[
i
]
num[i]
num[i])
记录
n
u
m
num
num的 前缀和 和 后缀和
对于每个询问
分类讨论
如果是小于
二分找到最大的下标
(
x
)
(x)
(x)使得
a
[
x
]
<
K
a[x]<K
a[x]<K,输出
x
x
x的前缀和
如果是大于
二分找到最小的下标
(
x
)
(x)
(x)使得
a
[
x
]
>
K
a[x]>K
a[x]>K,输出
x
x
x的后缀和
如果是等于
先二分找到最大的
x
x
x使得
a
[
x
]
=
K
a[x]=K
a[x]=K,再二分找到最大的
y
y
y使得
a
n
s
[
y
]
<
K
ans[y]<K
ans[y]<K,输出
x
x
x的前缀和-
y
y
y的前缀和
Code
#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
int n,m,i,l,r,mid,bj;
long long s1,s2,x,ans,ans1,ans2,pre[100005],suf[100005],tree[400005];
char ch;
struct node
{
long long val,sum;
}a[100005];
bool cmp(node x,node y)
{
return x.val<y.val;
}
void biuld(int now,int l,int r)
{
if (l==r)
{
tree[now]=a[l].val;
return;
}
int mid=(l+r)>>1;
biuld(now<<1,l,mid);
biuld(now<<1|1,mid+1,r);
tree[now]=max(tree[now<<1],tree[now<<1|1]);
}
long long query(int now,int l,int r,int p,int q)
{
if (l>=p&&r<=q) return tree[now];
int mid=(l+r)>>1;
long long res;
res=0;
if (p<=mid) res=max(res,query(now<<1,l,mid,p,q));
if (q>mid) res=max(res,query(now<<1|1,mid+1,r,p,q));
return res;
}
int main()
{
freopen("jxthree.in","r",stdin);
freopen("jxthree.out","w",stdout);
scanf("%d%d",&n,&m);
for (i=1;i<=n;i++)
scanf("%lld",&a[i].val);
biuld(1,1,n);
for (i=1;i<=n;i++)
{
l=1;
r=i-1;
s1=i;
s2=i;
while (l<=r)
{
mid=(l+r)>>1;
if (query(1,1,n,mid,i-1)<a[i].val)
{
s1=mid;
r=mid-1;
}
else l=mid+1;
}
l=i+1;
r=n;
while (l<=r)
{
mid=(l+r)>>1;
if (query(1,1,n,i+1,mid)<=a[i].val)
{
s2=mid;
l=mid+1;
}
else r=mid-1;
}
a[i].sum=(long long)(i-s1+1)*(long long)(s2-i+1);
}
sort(a+1,a+n+1,cmp);
for (i=1;i<=n;i++)
pre[i]=pre[i-1]+a[i].sum;
for (i=n;i>=1;i--)
suf[i]=suf[i+1]+a[i].sum;
for (i=1;i<=m;i++)
{
ch=getchar();
while (ch!='>'&&ch!='='&&ch!='<') ch=getchar();
if (ch=='<') bj=1;
if (ch=='=') bj=2;
if (ch=='>') bj=3;
scanf("%lld",&x);
ans=0;
ans1=0;
ans2=0;
if (bj==1)
{
l=1;
r=n;
while (l<=r)
{
mid=(l+r)>>1;
if (a[mid].val<x)
{
ans=mid;
l=mid+1;
}
else r=mid-1;
}
printf("%lld\n",pre[ans]);
}
if (bj==2)
{
l=1;
r=n;
while (l<=r)
{
mid=(l+r)>>1;
if (a[mid].val<=x)
{
ans1=mid;
l=mid+1;
}
else r=mid-1;
}
l=1;
r=n;
while (l<=r)
{
mid=(l+r)>>1;
if (a[mid].val<x)
{
ans2=mid;
l=mid+1;
}
else r=mid-1;
}
if (a[ans1].val==x) printf("%lld\n",pre[ans1]-pre[ans2]);
else printf("0\n");
}
if (bj==3)
{
l=1;
r=n;
while (l<=r)
{
mid=(l+r)>>1;
if (a[mid].val>x)
{
ans=mid;
r=mid-1;
}
else l=mid+1;
}
printf("%lld\n",suf[ans]);
}
}
fclose(stdin);
fclose(stdout);
return 0;
}