Alice has a magic array. She suggests that the value of a interval is equal to the sum of the values in the interval, multiplied by the smallest value in the interval.
Now she is planning to find the max value of the intervals in her array. Can you help her?
Input
First line contains an integer n(1 \le n \le 5 \times 10 ^5n(1≤n≤5×105).
Second line contains nn integers represent the array a (-10^5 \le a_i \le 10^5)a(−105≤ai≤105).
Output
One line contains an integer represent the answer of the array.
样例输入复制
5 1 2 3 4 5
样例输出复制
36
大致题意为有一个数组,数组有n个元素,每个元素有相应的大小,定义一个计算为选定一个区间,用这个区间内的最小值乘以这个区间内的元素加和,最后要求出所有这种计算所取得的最大值。
首先对题意进行分析,用区间内的最小值乘以区间和,那么区间应该怎么找,如果要最大化这个答案,很明显区间内的最小值是固定的,变化的只是以这个值为最小值的区间,那么很明显区间和应该越大越好;我们可以通过区间最小值来确定包含该最小值的最大区间和,由于该区间必须包含该最小值,所以我们可以利用前缀和,如果该最小值大于0,从该最小值的右边区间找到一个最大前缀和,从左边区间找出一个最小前缀和,然后做差与该区间的最小值相乘,就是这个区间所能贡献的最大答案;如果最小值小于零则反之处理。然后以这个值为最小值的区间已经找过,那么以后的i计算都不应该包含该值,所以可以由区间最小值来进行区间划分,然后归并处理,即可得到最终答案。
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <map>
#include <stack>
#include <queue>
#include <vector>
#include <bitset>
#include <set>
#include <utility>
#include <sstream>
#include <iomanip>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int MAX = (int)5e5 + 15;
const ll mod = 1e9+7;
pair<int,int> num[MAX];
pair<int,int> st1[MAX][20];
ll psum[MAX];
ll st2[MAX][20];
ll st3[MAX][20];
void ST(int n) // ST表预处理最大最小值,实现快速查询
{
for(int i=1;i<=n;i++){
st1[i][0]=num[i];
st2[i][0]=psum[i];
st3[i][0]=psum[i];
}
st1[0][0]=st1[1][0];
for(int k=1;k<=(int)log2(n);k++){
for(int i=0;i<=n-(1<<k)+1;i++){
st1[i][k]=min(st1[i][k-1],st1[i+(1<<(k-1))][k-1]);
st2[i][k]=min(st2[i][k-1],st2[i+(1<<(k-1))][k-1]);
st3[i][k]=max(st3[i][k-1],st3[i+(1<<(k-1))][k-1]);
}
}
}
ll div2(int l,int r) //归并查询
{
if(l>=r){
if(r<=0)return (ll)num[l].first*(ll)num[l].first;
else return (ll)num[r].first*(ll)num[r].first;
}
int k=log2(r-l+1);
pair<int,int> temp=min(st1[l][k],st1[r+1-(1<<k)][k]);
int mid=temp.second;
ll numr,numl;
ll nape=temp.first;
if(nape>=0){
k=(int)log2(r-mid+1);
numr=max(st3[mid][k],st3[r+1-(1<<k)][k]);
k=(int)log2(mid-l+2);
numl=min(st2[l-1][k],st2[mid+1-(1<<k)][k]);
}else{
k=(int)log2(r-mid+1);
numr=min(st2[mid][k],st2[r+1-(1<<k)][k]);
k=(int)log2(mid-l+2);
numl=max(st3[l-1][k],st3[mid+1-(1<<k)][k]);
}
nape=nape*(numr-numl);
return max(nape,max(div2(l,mid-1),div2(mid+1,r)));
}
int main()
{
int n;
cin>>n;
for(int i=1;i<=n;i++){
cin>>num[i].first;
num[i].second=i;
psum[i]=psum[i-1]+num[i].first;
}
ST(n);
cout<<div2(1,n)<<endl;
return 0;
}