传送门:牛客
题目描述:
珂朵莉给了你一个序列,有n*(n+1)/2 个子区间,求出她们各自的逆序对个数,然后加起来输出
输入:
10
1 10 8 5 6 2 3 9 4 7
输出:
270
对于这道题,我们发现和逆序对有关,但是我们肯定不能直接求出每一个区间的逆序对数然后加起来,这样肯定会超时,那么我们想一下有没有什么优化的方案.
我们可以计算每一个逆序对所有可以贡献的个数.举个例子,比如我们有
a
[
i
]
>
a
[
j
]
,
i
<
j
a[i]>a[j],i<j
a[i]>a[j],i<j,那么此时我们的
i
i
i和
j
j
j形成了一个逆序对
那么此时所有区间
[
l
,
r
]
[l,r]
[l,r]包含了这个逆序对,只要满足
l
<
=
i
,
r
>
=
j
l<=i,r>=j
l<=i,r>=j,那么此时对于这个逆序对来说,产生的贡献一共有
i
∗
(
n
−
j
+
1
)
i*(n-j+1)
i∗(n−j+1)
那么我们可以选择枚举所有的 j j j,然后对于每一个 j j j,我们都找出所有 i i i的贡献.对于这个,我们可以使用线段树进行维护,对于每一个 i i i,我们发现他的贡献是由他的初始编号决定的,所以对于每一个 i i i,我们都在 a [ i ] a[i] a[i]位置添加 i i i的编号即可.对于所有 i i i的贡献就只需要对大于 j j j的所有点进行一个求和即可
对于区间操作,我们观察l,r的范围,发现达到了 1 e 9 1e9 1e9,显然我们的线段树不能直接维护这么大的数字,所以我们需要进行离散化操作(对于具体的离散化操作可以参考代码)
需要注意的是,本题会爆longlong,但是庆幸的是并不会爆__int128,所以我们可以使用__int128而不用敲麻烦的高精度了
下面是具体的代码部分:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define root 1,n,1
#define ls rt<<1
#define rs rt<<1|1
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
inline ll read() {
ll x=0,w=1;char ch=getchar();
for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x*w;
}
#define int __int128
#define maxn 1000100
const double eps=1e-8;
#define int_INF 0x3f3f3f3f
#define ll_INF 0x3f3f3f3f3f3f3f3f
struct Segment_tree{
int l,r,sum;
}tree[maxn*4];
int n;int a[maxn];
void pushup(int rt) {
tree[rt].sum=tree[ls].sum+tree[rs].sum;
}
void build(int l,int r,int rt) {
tree[rt].l=l;tree[rt].r=r;
if(l==r) return ;
int mid=(l+r)>>1;
build(lson);build(rson);
pushup(rt);
}
void update(int pos,int v,int rt) {
if(tree[rt].l==pos&&tree[rt].r==pos) {
tree[rt].sum+=v;
return ;
}
int mid=(tree[rt].l+tree[rt].r)>>1;
if(pos<=mid) update(pos,v,ls);
else update(pos,v,rs);
pushup(rt);
}
int query(int l,int r,int rt) {
if(tree[rt].l==l&&tree[rt].r==r) {
return tree[rt].sum;
}
int mid=(tree[rt].l+tree[rt].r)>>1;
if(r<=mid) return query(l,r,ls);
else if(l>mid) return query(l,r,rs);
else return query(l,mid,ls)+query(mid+1,r,rs);
}
vector<int>v;
void print(int x) {
if(x>9) print(x/10);
putchar(x%10+'0');
}
signed main() {
n=read();
for(int i=1;i<=n;i++) {
a[i]=read();
v.push_back(a[i]);
}
sort(v.begin(),v.end());
v.erase(unique(v.begin(),v.end()),v.end());
int Size=v.size();
build(1,Size+10,1);
int ans=0;
for(int i=1;i<=n;i++) {
int x=lower_bound(v.begin(),v.end(),a[i])-v.begin()+1;
ans+=query(x+1,Size+10,1)*(n-i+1);
update(x,i,1);
}
print(ans);
return 0;
}