Problem Statement
You are given an integer sequence
A
=
(
A
1
,
A
2
,
…
,
A
N
)
A = (A_1, A_2, \dots, A_N)
A=(A1,A2,…,AN).
Calculate the following expression:
∑ i = 1 N ∑ j = i + 1 N max ( A j − A i , 0 ) \displaystyle \sum_{i=1}^N \sum_{j=i+1}^N \max(A_j - A_i, 0) i=1∑Nj=i+1∑Nmax(Aj−Ai,0)
The constraints guarantee that the answer is less than 2 63 2^{63} 263.
题意解析
题目要求我们求出
i
i
i之后的所有
m
a
x
(
0
,
a
[
j
]
−
a
[
i
]
)
max(0,a[j]-a[i])
max(0,a[j]−a[i])的值,我们可以将其转化以下,转化为求出
j
j
j之前的所有
a
[
i
]
<
a
[
j
]
a[i]<a[j]
a[i]<a[j]的值,并且记录有多少个这样的值
我们的答案
a
n
s
ans
ans即为
c
n
t
∗
a
[
j
]
−
s
u
m
(
a
[
i
]
)
cnt*a[j]-sum(a[i])
cnt∗a[j]−sum(a[i])
思路:
线段树维护所有小于
a
[
i
]
a[i]
a[i]的区间和,由于题目范围在
1
e
8
1e8
1e8我们可以将所有数值离散化,在离散化的数组上建立线段树,其中离散化的数组表示的含义为
a
[
i
]
a[i]
a[i]表示
i
i
i这个数出现过多少次,存储
i
i
i的倍数,由此我们求所有小于
i
i
i的区间和,即为查询
1
到
i
−
1
1到~i-1
1到 i−1中的区间和,我们查询完后需要修改当前位置上的数值
线段树的插入顺序可以帮我们维护数列的顺序
线段树的查询操作来获取数列的区间和以及区间中数的数量
代码解析
- 线段树基本操作
struct node
{
int l,r;
ll sum=0;//维护区间和
int cnt=0;//维护区间数量
}tr[N*4];
//pushup基本操作
void pushup(int u)
{
tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
tr[u].cnt=tr[u<<1].cnt+tr[u<<1|1].cnt;
}
//建立线段树,此时我们未插入任何元素
void build(int u,int l,int r)
{ tr[u]={l,r};
if(l==r)return;
else{
int mid=l+r>>1;
build(u<<1,l,mid);
build(u<<1|1,mid+1,r);
}
}
//查询操作,我们即需要cnt,又需要sum,所以我们可以返回节点
node query(int u,int l,int r)
{
if(tr[u].l>=l&&tr[u].r<=r)
{
return tr[u];//此范围直接返回
}
int mid=tr[u].l+tr[u].r>>1;
node temp{0,0},left,right;
if(l<=mid) left=query(u<<1,l,r);//递归查询左边
if(r>mid)right=query(u<<1|1,l,r);//递归查询右边
temp.sum=left.sum+right.sum;//计算当前子树
temp.cnt=left.cnt+right.cnt;//计算
return temp;
}
void modify(int u,int l,int k)
{
if(tr[u].l==l&&tr[u].r==l)
{
tr[u].sum+=k;
tr[u].cnt++;
return;
}
else{
int mid=tr[u].l+tr[u].r>>1;
if(l<=mid) modify(u<<1,l,k);
else modify(u<<1|1,l,k);
pushup(u);
}
}
- 离散化基本操作
离散化是将原来的大范围映射到一个小范围区间里去,从而更方便我们进行下一步操作,此题我们是将a的范围缩小
离散化模板
//表示将a1数组排序去重,方便我们一一映射
sort(a1.begin(),a1.end());
a1.erase(unique(a1.begin(),a1.end()),a1.end());
//find函数,表示寻找当前数字下标
int find(int x)//返回下表从1开始
{
return lower_bound(a1.begin(),a1.end(),x)-a1.begin()+1;
}
完整代码
#include<iostream>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
const int N=4e6+10;
typedef long long ll;
ll a[N];
int b[N];
vector<ll> a1;
int n;
ll res=0;
struct node
{
int l,r;
ll sum=0;
int cnt=0;
}tr[N*4];
int find(int x)
{
return lower_bound(a1.begin(),a1.end(),x)-a1.begin()+1;
}
void pushup(int u)
{
tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
tr[u].cnt=tr[u<<1].cnt+tr[u<<1|1].cnt;
}
void build(int u,int l,int r)
{ tr[u]={l,r};
if(l==r)return;
else{
int mid=l+r>>1;
build(u<<1,l,mid);
build(u<<1|1,mid+1,r);
}
}
node query(int u,int l,int r)
{
if(tr[u].l>=l&&tr[u].r<=r)
{
return tr[u];
}
int mid=tr[u].l+tr[u].r>>1;
node temp{0,0},left,right;
if(l<=mid) left=query(u<<1,l,r);
if(r>mid)right=query(u<<1|1,l,r);
temp.sum=left.sum+right.sum;
temp.cnt=left.cnt+right.cnt;
return temp;
}
void modify(int u,int l,int k)
{
if(tr[u].l==l&&tr[u].r==l)
{
tr[u].sum+=k;
tr[u].cnt++;
return;
}
else{
int mid=tr[u].l+tr[u].r>>1;
if(l<=mid) modify(u<<1,l,k);
else modify(u<<1|1,l,k);
pushup(u);
}
}
int main()
{
cin>>n;
for(int i=1;i<=n;i++)
{
cin>>a[i];
a1.push_back(a[i]);
}
sort(a1.begin(),a1.end());
a1.erase(unique(a1.begin(),a1.end()),a1.end());
build(1,1,a1.size());
for(int i=1;i<=n;i++)
{
int x=find(a[i]);
//node root;
node root=query(1,1,x-1);
res+=root.cnt*a[i]-root.sum;
modify(1,x,a[i]);
}
cout<<res<<endl;
return 0;
}