题目大意
给你一个序列,这个序列中任意一个区间的最大值减去最小值得到一个数,让你求所有的区间的这个数加起来的和是多少。分析
这道题可以用线性扫描pre数组的方法差不多 O(n) 的复杂度做出来
我做的时候是分治的思想,用线段树来维护, O(nlogn) 的复杂度
找到最大的那个元素,设下标为loc,就将问题分成了三部分:
⎧⎩⎨包含这个最大元素的区间这个最大元素左边的区间最大元素右边的区间
通过线段树来维护一个区间中最大值的下标即可.
线性扫描的思路
用一个pre数组保存某个元素前面比它大的第一个元素的下标,和一个pre2数组保存某个元素后面比它大的第一个元素的下标。
有了这个信息我们就可以求得以这个元素为最大值的区间又多少个。
在预处理pre数组的时候线性扫一遍往前跳就行了,预处理复杂度大概是 O(n)代码
#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#include<cstdlib>
#include<queue>
#include<map>
#include<algorithm>
#include<set>
#include<stack>
using namespace std;
#define LL long long int
const int INF=0x3f3f3f3f;
const int MAXN=1000005;
LL n;
LL a[MAXN];
LL maxn[MAXN*4];//区间最大值
LL minn[MAXN*4];
LL maxn_loc[MAXN*4];//最大值在a数组中的下标
LL minn_loc[MAXN*4];//最小值在a数组中的下标
void In()
{
scanf("%lld",&n);
for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
}
void Pushup(int rt)
{
if(maxn[rt*2] > maxn[rt*2+1])
{
maxn[rt]=maxn[rt*2];
maxn_loc[rt]=maxn_loc[rt*2];
}
else
{
maxn[rt]=maxn[rt*2+1];
maxn_loc[rt]=maxn_loc[rt*2+1];
}
if(minn[rt*2] < minn[rt*2+1])
{
minn[rt]=minn[rt*2];
minn_loc[rt]=minn_loc[rt*2];
}
else
{
minn[rt]=minn[rt*2+1];
minn_loc[rt]=minn_loc[rt*2+1];
}
}
void Build(int l,int r,int rt)
{
if(l==r)
{
maxn[rt]=a[l];
maxn_loc[rt]=l;
minn[rt]=a[l];
minn_loc[rt]=l;
return ;
}
int m=(l+r)/2;
Build(l,m,rt*2);
Build(m+1,r,rt*2+1);
Pushup(rt);
}
LL Query_max(int L,int R,int l,int r,int rt)//返回最大值所在的下标
{
if(l>=L && r<=R)
{
return maxn_loc[rt];
}
LL ans=-INF;
int ans_loc;
int m=(l+r)/2;
if(L<=m)
{
int loc=Query_max(L,R,l,m,rt*2);
if(a[loc]>ans){ans_loc=loc;ans=a[loc];}
}
if(R>m)
{
int loc=Query_max(L,R,m+1,r,rt*2+1);
if(a[loc]>ans){ans_loc=loc;ans=a[loc];}
}
return ans_loc;
}
LL Query_min(int L,int R,int l,int r,int rt)//返回最大值所在的下标
{
if(l>=L && r<=R)
{
return minn_loc[rt];
}
LL ans=INF;
int ans_loc;
int m=(l+r)/2;
if(L<=m)
{
int loc=Query_min(L,R,l,m,rt*2);
if(a[loc]<ans){ans_loc=loc;ans=a[loc];}
}
if(R>m)
{
int loc=Query_min(L,R,m+1,r,rt*2+1);
if(a[loc]<ans){ans_loc=loc;ans=a[loc];}
}
return ans_loc;
}
LL Work(int l,int r,int is_max)//返回最大值
{
if(l>r)return 0;
if(l==r)return a[l];
LL ans=0;
LL loc;
if(is_max==1)loc=Query_max(l,r,1,n,1);
else loc=Query_min(l,r,1,n,1);
ans=a[loc]*(double)(loc-l+1)*(double)(r-loc+1);
ans+=Work(l,loc-1,is_max);
ans+=Work(loc+1,r,is_max);
return ans;
}
int main()
{
In();
Build(1,n,1);
printf("%lld\n",Work(1,n,1)-Work(1,n,0));
return 0;
}
/*
3
1 5 5
5
2 1 5 3 5
*/