利用单调栈维护一个三元组 ( v a l i , L i , R i ) (val_i,L_i,R_i) (vali,Li,Ri)
对于每个值,我们都假设其为最小值;
看看这个值,作为最小值能取的区间是多少;
比如数据
6
3 1 6 4 5 2
有以下几个三元组
(
3
,
1
,
1
)
(3,1,1)
(3,1,1)
(
1
,
1
,
6
)
(1,1,6)
(1,1,6)
(
6
,
3
,
3
)
(6,3,3)
(6,3,3)
(
4
,
3
,
5
)
(4,3,5)
(4,3,5)
(
2
,
3
,
6
)
(2,3,6)
(2,3,6)
我们可以发现答案在
(
4
,
3
,
5
)
(4,3,5)
(4,3,5)这个三元组里
具体怎么操作呢?
我们这里维护一个单调递增栈;
当栈顶元素小于我们当前的元素 a i a_i ai,直接入栈;
当栈顶元素大于等于我们当前的元素 a i a_i ai时,进行区间扩张;
怎么扩张呢?
将栈顶元素记为 t o p i top_i topi;
将 t o p i top_i topi的右区间赋给栈顶下一个元素 t o p i − 1 top_{i-1} topi−1的右区间(注意判断有没有下一个元素)
将 t o p i top_i topi的左区间赋给当前元素 a i a_i ai的左区间
如此循环操作即可
最后,防止像这样的数据
3
7 8 9
我们需要将元素一个个弹出栈;
并按上面所说的,将 t o p i top_{i} topi的右区间赋值给 t o p i − 1 top_{i-1} topi−1的右区间;
稍微解释一下;
因为我们的元素区间是按顺序放置的;
并且我们维护的是单调栈;
也就是栈顶是大的,栈底是小的;
那么如果一个元素能新进来,那么说明从栈顶元素出生到当前新元素的生命结束,新元素的值都是最小的;
同理,栈下部的元素从其生命开始到栈上部元素的生命结束,它都是最小的;
不错,如此划分区间是会重叠的,但是符合题意;
手写栈的方式
#include <iostream>
#include <cstdio>
#include <unordered_map>
#include <algorithm>
#include <utility>
#include <cmath>
using namespace std;
typedef long long ll;
const int N = 100000+10;
ll a[N],l[N],r[N],sta[N],s[N];
struct Node{
ll val,l,r;
};
Node ans;
int main()
{
ans.l = ans.r = 1;
ll n;
cin >> n;
for(ll i=1;i<=n;++i) cin >> a[i],l[i]=r[i]=i,s[i]=s[i-1]+a[i];
ll top = 2;//[1~top)
sta[1] = 1;
for(ll i=2;i<=n;++i){
ll pos = sta[top-1];
while(top>1&&a[pos]>=a[i]){
ll val = (s[r[pos]]-s[l[pos]-1])*a[pos];
if(ans.val < val){
ans.val = val;
ans.l = l[pos];
ans.r = r[pos];
}
--top;
if(top>1){
r[sta[top-1]] = r[pos];
}
l[i] = l[pos];
pos = sta[top-1];
}
sta[top++] = i;
}
while(top>1){
ll pos = sta[top-1];
ll val = (s[r[pos]]-s[l[pos]-1])*a[pos];
if(ans.val < val){
ans.val = val;
ans.l = l[pos];
ans.r = r[pos];
}
--top;
if(top>1) r[sta[top-1]] = r[pos];
}
cout << ans.val << '\n';
cout << ans.l << ' ' << ans.r <<'\n';
return 0;
}
stl栈的方式
#include <iostream>
#include <stack>
using namespace std;
const int N = 100000+10;
typedef long long ll;
struct Node{
ll val,l,r;
}a[N];
ll s[N];
stack<Node> sta;
Node ans;
int main()
{
ans.l = ans.r = 1;
int n;
cin >> n;
for(int i=1;i<=n;++i){
cin >> a[i].val;
a[i].l=a[i].r=i;
s[i] = s[i-1] + a[i].val;
}
for(int i=1;i<=n;++i){
while(!sta.empty()&&a[i].val <= sta.top().val){
Node tmp = sta.top();
sta.pop();
ll val = tmp.val * (s[tmp.r] - s[tmp.l-1]);
if(ans.val < val){
ans.val = val;
ans.l = tmp.l;
ans.r = tmp.r;
}
if(!sta.empty()){
sta.top().r = tmp.r;
}
a[i].l = tmp.l;
}
sta.push(a[i]);
}
while(!sta.empty()){
Node tmp = sta.top();
sta.pop();
if(!sta.empty()){
sta.top().r = tmp.r;
}
ll val = tmp.val * (s[tmp.r] - s[tmp.l-1]);
if(ans.val < val){
ans.val = val;
ans.l = tmp.l;
ans.r = tmp.r;
}
}
cout << ans.val << '\n';
cout << ans.l << ' ' << ans.r <<'\n';
return 0;
}