题目描述
给定长度为n的序列{a},求
∑
i
=
1
n
∑
j
=
i
n
M
a
x
[
i
,
j
]
∗
M
i
n
[
i
,
j
]
\sum_{i=1}^{n}\sum_{j=i}^{n}Max[i,j]*Min[i,j]
∑i=1n∑j=inMax[i,j]∗Min[i,j]
Max和Min为对应区间内的极值
数据范围
1 ≤ \leq ≤n ≤ \leq ≤ 1e5,a[i] ≤ \leq ≤???(反正不用高精度)
题解
不要问为什么用分治,我也不知道
显然,如果每次将当前区间的最大/小值作为分治点,虽然可以省事,但在极端情况下(有序序列)会被卡成O
(
n
2
)
(n^2)
(n2)。
所以,老老实实分治中点吧。
设当前讨论区间为[l,r],中点为mid。根据分治原理,我们只要处理完跨中点的贡献,就可以递归处理了。。(记得特判长度为1的情况)
我们处理出左区间[l,mid]的后缀极值,及右区间[mid+1,r]的前缀极值。。。
显然4个数组都具有单调性。。。
所以我们可以仿照CDQ分治的套路,移动两个指针来解决。。
具体。。。我们有pmax和pmin两个指针(整形变量,只是指针的效果而已)。当讨论到i作为左端点的跨中点区间时,pmax和pmin指向能让区间最大(小)值与[i,mid]时相等的最右端点。。。说简单点,
M
a
x
[
i
,
p
m
a
x
]
=
=
M
a
x
[
i
,
m
i
d
]
,
M
a
x
[
i
,
p
m
a
x
+
1
]
>
M
a
x
[
i
,
m
i
d
]
Max[i,pmax]==Max[i,mid],Max[i,pmax+1]>Max[i,mid]
Max[i,pmax]==Max[i,mid],Max[i,pmax+1]>Max[i,mid]
最小值同理。。。并且可以发现,当i减小时,pmax和pmin不会向左移动
然后我们考虑如何计算对答案的贡献。。。
如下图,是一种情况(pmax< pmin时同理)
对于第1部分,显然这些位置作为右端点时极值都在左边,所以直接用点i的前缀极值相乘,再乘上区间长度即可。
对于第2部分,此时最小值已经不取决于区间[i,mid],但最大值仍然是。我们可以处理出右区间前缀极值的前缀和(听起来好恶心),贡献为
M
a
x
[
i
,
m
i
d
]
∗
(
M
i
n
S
u
m
[
p
m
a
x
]
−
M
i
n
S
u
m
[
p
m
i
n
]
)
Max[i,mid]*(MinSum[pmax]-MinSum[pmin])
Max[i,mid]∗(MinSum[pmax]−MinSum[pmin])。
对于第三部分,此时两种极值都不取决于[i,mid]。我们还可以处理出前缀极值相乘的前缀和(啊啊啊好恶心),那么贡献为
A
l
l
S
u
m
[
r
]
−
A
l
l
S
u
m
[
p
m
a
x
]
AllSum[r]-AllSum[pmax]
AllSum[r]−AllSum[pmax]。
具体实现参见代码。
代码
1.暴力
#include <cstdio>
#include <iostream>
using namespace std;
const int Q=300005;
int a[Q];
int owo(int l,int r)
{
int p1=-2100000000,p2=2100000000;
for(int i=l;i<=r;i++)
p1=max(p1,a[i]),p2=min(p2,a[i]);
return p1*p2;
}
int main()
{
int i,j,n,ans=0;
scanf("%d",&n);
for(i=1;i<=n;i++)
scanf("%d",&a[i]);
for(i=1;i<=n;i++)
for(j=i;j<=n;j++)
ans+=owo(i,j);
printf("%d",ans);
return 0;
}
2.分治
//附上一组(良心)数据
#include <cstdio>
#include <iostream>
#include <algorithm>
#define int long long
using namespace std;
const int Q=300005;
int ans=0,mx[Q],mi[Q],a[Q],alx[Q],aln[Q],sum[Q];
void solve(int l,int r)
{
if(l==r){
ans+=a[l]*a[l];
return;
}
int mid=l+r >>1,px,pn;
solve(mid+1,r);
solve(l,mid);
int ops=ans;
mx[mid]=mi[mid]=a[mid];
mx[mid+1]=mi[mid+1]=a[mid+1];
for(int i=mid-1;i>=l;i--)mx[i]=max(mx[i+1],a[i]),mi[i]=min(mi[i+1],a[i]);
for(int i=mid+2;i<=r;i++)mx[i]=max(mx[i-1],a[i]),mi[i]=min(mi[i-1],a[i]);
aln[mid]=alx[mid]=sum[mid]=0;
for(int i=mid+1;i<=r;i++)sum[i]=sum[i-1]+mx[i]*mi[i],alx[i]=alx[i-1]+mx[i],aln[i]=aln[i-1]+mi[i];
px=pn=mid+1;
for(int i=mid;i>=l;i--)
{
while(px<=r&&mx[i]>=mx[px])++px;
while(pn<=r&&mi[i]<=mi[pn])++pn;
int maxn=px-1,minn=pn-1;
ans+=(min(maxn,minn)-mid)*mx[i]*mi[i];
ans+=sum[r]-sum[max(maxn,minn)];
if(maxn>minn)ans+=mx[i]*(aln[maxn]-aln[minn]);
else ans+=mi[i]*(alx[minn]-alx[maxn]);
}
}
main()
{
int i,n;
scanf("%lld",&n);
for(i=1;i<=n;i++)
scanf("%lld",&a[i]);
solve(1,n);
printf("%lld",ans);
return 0;
}
/*
6
1 4 14 5 2 4
*/