Description
现在有n个数字a[1…n],对于一个长度至少为3的区间[l,r],定义区间的价值为区间中最大的三个数的乘积。求所有区间的价值和。
Input
从文件three.in读入。
第一行一个整数n。第二行n个整数,表示a[1…n]
Output
输出到文件three.out。
一行一个整数,表示所有区间的价值和,答案对10^9+7取模。
Sample Input
5
1 2 3 4 5
Sample Output
234
Data Constraint
对于所有数据,保证n<=106,0<ai<=109,所有数字互不相同
对于10%的数据,n<=100
对于30%的数据,n<=1000
对于70%的数据,n<=10^5
对于100%的数据,n<=10^6
题解
话说好久没打博客了 (csdn怎么变美观了?)
首先30分都很简单,直接顺推过去即可。
然后我们发现30分都是维护三个最大值,那么我们考虑定下这三个最大值。
如果是定下三个最大值,那么他们所选的区间必定是左边顶到一个比他们还大的值,右边同理。
那么我们考虑从大到小来把值插进去原序列。
那么当前插进去的值对答案所造成的贡献只有下面三种情况:
1、左右各选一个
2、左边选两个
3、右边选两个
那么问题就变成如何快速找这些点了。
目前已知有三种方法,set、并查集、线段树。
我比较菜,只会线段树。
那么我们为了便于维护,我们在头尾各插入3个数。
然后线段树维护区间siz,树上二分即可。
时间复杂度大概是 O ( n l o g n ∗ 大 常 数 ) O(n\ log\ n*大常数) O(n log n∗大常数)
代码
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
const long long mo=1000000007;
const int maxn=1000010;
int n,m,a[maxn],t[maxn*4],siz[maxn*4],pm;
long long a1,b1,ans,x[7],y[7],id[maxn];
void qsort(int l,int r)
{
int i=l;int j=r;
int m=a[(i+j)/2];
while (i<=j)
{
while (a[i]>m) i++;
while (a[j]<m) j--;
if (i<=j)
{
swap(a[i],a[j]);
swap(id[i],id[j]);
i++;j--;
}
}
if (l<j) qsort(l,j);
if (r>i) qsort(i,r);
}
void insert(int x,int l,int r,int st,int gg)
{
if (l==r)
{
siz[x]=1;
t[x]=gg;
}
else
{
int mid=(l+r)/2;
if (st<=mid)
{
insert(x*2,l,mid,st,gg);
}
else
{
insert(x*2+1,mid+1,r,st,gg);
}
siz[x]=siz[x*2]+siz[x*2+1];
}
}
void getpm(int x,int l,int r,int st)
{
if (l==r)
{
return;
}
else
{
int mid=(l+r)/2;
if (st<=mid)
{
getpm(x*2,l,mid,st);
}
else
{
pm+=siz[x*2];
getpm(x*2+1,mid+1,r,st);
}
}
}
void find(int x,int l,int r,int ss)
{
if (l==r)
{
a1=t[x];
b1=l;
}
else
{
int mid=(l+r)/2;
if (siz[x*2]>=ss)
{
find(x*2,l,mid,ss);
}
else
{
find(x*2+1,mid+1,r,ss-siz[x*2]);
}
}
}
int main()
{
freopen("three.in","r",stdin);
freopen("three.out","w",stdout);
scanf("%d",&n);
for (int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
id[i]=i+3;
}
qsort(1,n);
insert(1,1,n+6,1,0);
insert(1,1,n+6,2,0);
insert(1,1,n+6,3,0);
insert(1,1,n+6,n+6,0);
insert(1,1,n+6,n+5,0);
insert(1,1,n+6,n+4,0);
insert(1,1,n+6,id[1],a[1]);
insert(1,1,n+6,id[2],a[2]);
for (int i=3;i<=n;i++)
{
a1=0;
b1=0;
pm=0;
getpm(1,1,n+6,id[i]);
find(1,1,n+6,pm);
x[1]=a1;y[1]=b1;
find(1,1,n+6,pm-1);
x[2]=a1;y[2]=b1;
find(1,1,n+6,pm-2);
x[3]=a1;y[3]=b1;
find(1,1,n+6,pm+1);
x[4]=a1;y[4]=b1;
find(1,1,n+6,pm+2);
x[5]=a1;y[5]=b1;
find(1,1,n+6,pm+3);
x[6]=a1;y[6]=b1;
ans=(ans+x[1]*x[4]%mo*a[i]%mo*(y[1]-y[2])%mo*(y[5]-y[4])%mo)%mo;
ans=(ans+x[1]*x[2]%mo*a[i]%mo*(y[2]-y[3])%mo*(y[4]-id[i])%mo)%mo;
ans=(ans+x[4]*x[5]%mo*a[i]%mo*(id[i]-y[1])%mo*(y[6]-y[5])%mo)%mo;
insert(1,1,n+6,id[i],a[i]);
}
printf("%lld\n",ans);
}