两个数列 {An} , {Bn} ,请求出Ans, Ans定义如下:
Ans:=Σni=1Σnj=i[max{Ai,Ai+1,...,Aj}=max{Bi,Bi+1,...,Bj}]
注:[ ]内表达式为真,则为1,否则为0.
1≤N≤3.5×1051≤Ai,Bi≤N
样例解释:
7个区间分别为:(1,4),(1,5),(2,4),(2,5),(3,3),(3,5),(4,5)
Input
第一行一个整数N 第二行N个整数Ai 第三行N个整数Bi
Output
一行,一个整数Ans
Input示例
5 1 4 2 3 4 3 2 2 4 1
Output示例
7
shlw
(题目提供者)
题意:给定A和B数组,问有几对<i,j>满足a[i]~a[j]的最大值和b[i]~b[j]的最大值相等,(公共区间的最大值相等)。
思路:官方是单调栈+数组拉链O(n)解决,先留坑。考虑分治,计算以mid为分界的合法区间数。
有如下定义
La[i]为max(a[i]...a[mid])
Lb[i]为max(b[i]...b[mid])
Ra[j]为max(a[mid+1]...a[j])
Rb[j]为max(b[mid+1]...b[j])。
遍历i:l~mid,如果La[i] > Lb[i],那么mid右边哪些区间是符合呢?显然是Rb[j]==La[i] && Rb[j] >= Ra[j]的区间,小于也同理,等于的话前面两种都符合。为了防止计算重复,Ra[j]==Rb[j] && max(La[i],Lb[i]) < Ra[j]的情况另外计算,比如(1,2,8)和(3, 2, 8)这种,即max值在右半边区间取得,这个满足单调性可以O(n)解决,当然也可以统一用二分解决。
# include <bits/stdc++.h>
using namespace std;
const int maxn = 350000;
int a[maxn+3], b[maxn+3];
long long ans = 0;
void $(int l, int r, int L, int R)
{
int x=0, y=0;
vector<int>p,q,w;
for(int i=L; i<=R; ++i)
{
x = max(x, a[i]);
y = max(y, b[i]);
if(x>y) p.push_back(x);
else if(x<y) q.push_back(y);
else w.push_back(x);
}
x = y = 0;
for(int i=r; i>=l; --i)
{
x = max(x, a[i]);
y = max(y, b[i]);
if(x>y)
{
ans += upper_bound(q.begin(), q.end(), x)-lower_bound(q.begin(), q.end(), x);
ans += upper_bound(w.begin(), w.end(), x)-lower_bound(w.begin(), w.end(), x);
}
else if(x<y)
{
ans += upper_bound(p.begin(), p.end(), y) - lower_bound(p.begin(), p.end(), y);
ans += upper_bound(w.begin(), w.end(), y)-lower_bound(w.begin(), w.end(), y);
}
else
{
ans += upper_bound(p.begin(), p.end(), x)-p.begin();
ans += upper_bound(q.begin(), q.end(), x)-q.begin();
ans += upper_bound(w.begin(), w.end(), max(x,y))-w.begin();
}
}
x = y = 0;
int xx=0, yy=0;
int k=r+1;
for(int i=L; i<=R; ++i)//最大值在右半边区间取得的情况。
{
x = max(x, a[i]);
y = max(y, b[i]);
if(x != y) continue;
while(k-1>=l && max(max(xx,a[k-1]), max(yy,b[k-1])) < x)
--k, xx=max(xx, a[k]), yy=max(yy,b[k]);
ans += r+1-k;
}
}
void solve(int l, int r)
{
if(l==r)
{
ans+=(a[l]==b[l]);
return;
}
int mid = l+r>>1;
$(l, mid, mid+1, r);
solve(l, mid);
solve(mid+1, r);
}
int main()
{
int n;
scanf("%d",&n);
for(int i=1; i<=n; ++i) scanf("%d",&a[i]);
for(int i=1; i<=n; ++i) scanf("%d",&b[i]);
solve(1, n);
printf("%lld\n",ans);
return 0;
}