首先分治,然后答案转变为求区间[l,r]中,经过终点mid=(l+r)>>1的子串[x,y]的答案之和。
那么不妨枚举左端点为x,那么显然可以得到区间[x,mid]的最小值u和最大值v。同时维护两个指针j,k,表示最远的j使得[mid+1,j]的最小值没有u小,k维护最大值。那么:
1.对于所有y∈[mid+1,min(j,k)],最小值为u,最大值为v,直接利用高斯求和得到答案;
2.对于所有y∈[max(j,k)+1,r],最小值为[mid+1,y]的最小值,最大值为[mid+1,y]的最大值,预处理所有[mid+1,r]的答案得到;
3.不妨设j<k,那么对于所有y∈[j+1,k],最小值为[mid+1,y]的最小值,最大值为v,预处理所有[mid+1,r]的最大值为这部分答案的影响即可。
时间O(NlogN)。
AC代码如下:
#include<iostream>
#include<cstdio>
#include<cstring>
#define mod 1000000000
#define ll long long
#define N 500005
using namespace std;
int n,ans,a[N],c[N][2],f[N],g[N],p[N][2],q[N][2];
int read(){
int x=0; char ch=getchar();
while (ch<'0' || ch>'9') ch=getchar();
while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }
return x;
}
void ad(int &x,int y){ x+=y; if (x>=mod) x-=mod; }
void dl(int &x,int y){ x-=y; if (x<0) x+=mod; }
int getsum(int x,int y){ return ((ll)(x+y)*(y-x+1)>>1)%mod; }
void solve(int l,int r){
if (l==r){ ad(ans,(ll)a[l]*a[l]%mod); return; }
int mid=(l+r)>>1,i;
solve(l,mid); solve(mid+1,r);
c[mid][0]=c[mid][1]=a[mid];
for (i=mid-1; i>=l; i--){
c[i][0]=min(c[i+1][0],a[i]); c[i][1]=max(c[i+1][1],a[i]);
}
int mn=mod,mx=-mod;
f[mid]=g[mid]=p[mid][0]=p[mid][1]=q[mid][0]=q[mid][1]=0;
for (i=mid+1; i<=r; i++){
mn=min(mn,a[i]); mx=max(mx,a[i]);
f[i]=(ll)mn*mx%mod*(i-mid)%mod; ad(f[i],f[i-1]);
g[i]=(ll)mn*mx%mod; ad(g[i],g[i-1]);
p[i][0]=(p[i-1][0]+mn)%mod; q[i][0]=(q[i-1][0]+mx)%mod;
p[i][1]=(ll)mn*(i-mid)%mod; ad(p[i][1],p[i-1][1]);
q[i][1]=(ll)mx*(i-mid)%mod; ad(q[i][1],q[i-1][1]);
}
int j=mid,k=mid;
for (i=mid; i>=l; i--){
while (j<r && c[i][0]<a[j+1]) j++;
while (k<r && c[i][1]>a[k+1]) k++;
ad(ans,(ll)c[i][0]*c[i][1]%mod*getsum(mid-i+2,min(j,k)-i+1)%mod);
ad(ans,((ll)g[r]*(mid-i+1)+f[r])%mod);
dl(ans,((ll)g[max(j,k)]*(mid-i+1)+f[max(j,k)])%mod);
if (j<k){
ad(ans,((ll)p[k][0]*(mid-i+1)+p[k][1])%mod*c[i][1]%mod);
dl(ans,((ll)p[j][0]*(mid-i+1)+p[j][1])%mod*c[i][1]%mod);
} else{
ad(ans,((ll)q[j][0]*(mid-i+1)+q[j][1])%mod*c[i][0]%mod);
dl(ans,((ll)q[k][0]*(mid-i+1)+q[k][1])%mod*c[i][0]%mod);
}
}
}
int main(){
n=read(); int i;
for (i=1; i<=n; i++) a[i]=read();
solve(1,n); printf("%d\n",ans);
return 0;
}
by lych
2016.4.20