该题要求满足 i<j<k 且 ai<aj>ak 的三元组(i,j,k)的个数。
对于经典的逆序对的一种求解方法是对于元素 ai 求出满足 aj>ai 且 i<j 的元素的个数,线段树,树状数组以及平衡树都可以支持这个操作,用平衡树简单清晰,只需要依次插入每个元素并求一下当前平衡树中大于 ai 的元素的个数累加进答案即可。
对于本题只不过多求一遍。
1:顺序插入每个元素,插入前求出当前元素在当前平衡树中的rank记录为 b[i] 。
2:逆序插入每个元素,插入前求出当前元素在当前平衡树中的rank记录为 c[i] 。
3:根据乘法原理,ans=Σ b[i]*c[i]。
我写这道题只是为了练一下splay双旋的板子,没想到把splay操作写挂了,而且挂的离谱,我自己看到后都震惊了,写程序的时候得想着啥才能把splay写成这样......
第一次交的时候的splay代码。
void splay(int x,int &p) {
int y=f[x],z=f[y],q=f[p];
while(y!=q) {
if(z==q) x==l[y]?r_rot(x):l_rot(x);
else if(x==l[y]&&y==l[z]) r_rot(y),r_rot(x);
else if(x==l[y]&&y==r[x]) r_rot(x),l_rot(y);
else if(x==r[x]&&y==l[x]) l_rot(x),r_rot(x);
else l_rot(y),l_rot(x); y=f[x],z=f[y];
}
p=x;
}
那是一坨什么东西......
// q.c
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
const int M=50000+10;
struct SplayTree {
int root,cnt,l[M],r[M],f[M],v[M],s[M];
void clear() {
root=cnt=0;
memset(l,0,sizeof(l));
memset(r,0,sizeof(r));
memset(f,0,sizeof(f));
memset(v,0,sizeof(v));
memset(s,0,sizeof(s));
}
void update(int x) {
s[x]=s[l[x]]+s[r[x]]+1;
}
void l_rot(int x) {
int y=f[x],z=f[y]; f[x]=z;
if(z) y==l[z]?l[z]=x:r[z]=x;
if(l[x]) f[l[x]]=y;
r[y]=l[x],f[y]=x,l[x]=y;
update(y),update(x);
}
void r_rot(int x) {
int y=f[x],z=f[y]; f[x]=z;
if(z) y==l[z]?l[z]=x:r[z]=x;
if(r[x]) f[r[x]]=y;
l[y]=r[x],f[y]=x,r[x]=y;
update(y),update(x);
}
void splay(int x,int &p) {
int y=f[x],z=f[y],q=f[p];
while(y!=q) {
if(z==q) x==l[y]?r_rot(x):l_rot(x);
else if(x==l[y]&&y==l[z]) r_rot(y),r_rot(x);
else if(x==l[y]&&y==r[z]) r_rot(x),l_rot(x);
else if(x==r[y]&&y==l[z]) l_rot(x),r_rot(x);
else l_rot(y),l_rot(x); y=f[x],z=f[y];
}
p=x;
}
void insert(int &x,int fa,int k) {
if(!x) x=++cnt,f[x]=fa,v[x]=k,s[x]=1;
else {
if(k<=v[x]) insert(l[x],x,k);
else insert(r[x],x,k);
update(x);
}
}
int query(int k) {
int ans=0,x=root,px=root;
while(x) {
if(v[x]<k) ans+=s[l[x]]+1,px=x,x=r[x];
else x=l[x];
}
splay(px,root);
return ans;
}
}t;
int n,a[M],b[M],c[M]; long long ans;
int main() {
freopen("queueb.in","r",stdin);
freopen("queueb.out","w",stdout);
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
t.clear();
for(int i=1;i<=n;i++) {
b[i]=t.query(a[i]);
t.insert(t.root,0,a[i]);
}
t.clear();
for(int i=n;i>=1;i--) {
c[i]=t.query(a[i]);
t.insert(t.root,0,a[i]);
}
for(int i=1;i<=n;i++) ans+=b[i]*c[i];
printf("%lld\n",ans);
return 0;
}