假设a中的元素互不相同,我们考虑a中的某个元素作为min的时刻。 • 对于每个a[i],我们找到左边第一个比它小的元素a[l],右边第一个比它小的a[r] • 那么左端点在[l+1,i],右端点在[i,r-1]的区间min就为它。 • 求l和r可以使用单调栈。 • 考虑如何求某个a[i]作为min的答案。 • 如果a[i]>0我们就是要最大化sum(b[l…r])。a[i]<0就是要最小化。 • 记b的前缀和为s,那么sum(b[l…r])=s[r]-s[l-1]。 • 所以我们只要查询i…r-1最大的s和l…i-1最小的s,相减即可。 • 查询区间最小值可以使用st表或线段树等数据结构。 • 也可以直接建立笛卡尔树,然后维护子树中s的最大最小值,从而做到O(n)的复杂度。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<queue>
#include<cstdlib>
#include<cmath>
#include<stack>
#include<map>
#include<string>
#include<vector>
#include<set>
#include<bitset>
#include<algorithm>
using namespace std;
#define ll long long
#define INF 0x3f3f3f3f
#define ull unsigned long long
#define endl '\n'
#define clr(a) memset(a, 0, sizeof(a))
#define lowbit(x) x & -x
#define lson rt << 1, l, mid
#define rson rt << 1 | 1, mid + 1, r
#define PB push_back
#define POP pop_back
const double pi = acos(-1);
const int maxn = 3e6 + 101;
const int maxm = 100 + 101;
const ll mod = 1e9 + 7;
const int hash_mod = 19260817;
int n;
ll a[maxn], b[maxn], sum[maxn];
ll mx[maxn<<2], mi[maxn<<2];
void build(int rt, int l, int r){
if(l == r){
mx[rt] = mi[rt] = sum[l];
return;
}
int mid = (l + r) >> 1;
if(l <= mid) build(lson);
if(r > mid) build(rson);
mx[rt] = max(mx[rt<<1], mx[rt<<1|1]);
mi[rt] = min(mi[rt<<1], mi[rt<<1|1]);
}
ll query_min(int rt, int l, int r, int L, int R){
if(L <= l && r <= R) return mi[rt];
int mid = (l + r) >> 1;
ll ans = 3e18;
if(L <= mid) ans = min(ans, query_min(lson, L, R));
if(R > mid) ans = min(ans, query_min(rson, L, R));
return ans;
}
ll query_max(int rt, int l, int r, int L, int R){
if(L <= l && r <= R) return mx[rt];
int mid = (l + r) >> 1;
ll ans = -3e18;
if(L <= mid) ans = max(ans, query_max(lson, L, R));
if(R > mid) ans = max(ans, query_max(rson, L, R));
return ans;
}
int st[maxn], l[maxn], r[maxn], top;
int main()
{
scanf("%d", &n);
for(int i = 1 ; i <= n ; ++ i) scanf("%lld", &a[i]);
for(int i = 1 ; i <= n ; ++ i) scanf("%lld", &b[i]), sum[i] = sum[i-1] + b[i];
build(1, 0, n);
a[0] = -3e18; a[n + 1] = -3e18;
top = 0;
for(int i = 1 ; i <= n + 1 ; ++ i){
while(top && a[i] < a[st[top]]) r[st[top]] = i, top --;
st[++top] = i;
}
top = 0;
for(int i = n ; i >= 0 ; -- i){
while(top && a[i] < a[st[top]]) l[st[top]] = i, top --;
st[++top] = i;
}
ll ans = -3e18;
//for(int i = 0 ; i <= n ; ++ i) cout << sum[i] << ' ';
//cout << endl;
for(int i = 1 ; i <= n ; ++ i){
if(a[i] >= 0){
if(l[i] + 1 == i){
ans = max(ans, (query_max(1, 0, n, i, r[i] - 1) - sum[i-1]) * a[i]);
continue;
}
ans = max(ans, (query_max(1, 0, n, i, r[i] - 1) - query_min(1, 0, n, l[i], i - 1)) * a[i]);
}
else{
if(l[i] + 1 == i){
ans = max(ans, (query_min(1, 0, n, i, r[i] - 1) - sum[i-1]) * a[i]);
continue;
}
ans = max(ans, (query_min(1, 0, n, i, r[i] - 1) - query_max(1, 0, n, l[i], i - 1)) * a[i]);
}
}
cout << ans;
return 0;
}