题意:有一个A序列,A序列的长度为N。在这个序列中有一些子集(a,b,c,d),满足a!=b!=c!=d,1<=a<b<=n,1<=c<b<=n,Aa<Ab,Ac>Ad。最后输出这些序列的个数。
思路:容斥原理和数据离散化和树状数组
(1)frontSmall[k]表示在下标k之前比A[k]小的个数 (用树状数组求)
(2)frontLarge[k]表示在下标k之前比A[k]大的个数
(3)backSmall[k]表示在下标k之后比A[k]小的个数
(4)backLarge[k]表示在下标k之后比A[k]小的个数
(5)先求出一共有多少对Aa<Ab和Ac>Ad,然后相乘。
(6)去掉一些重复的,一共有四种重复的情况:
【1】a == c (b一定不等于d)
Aa < Ab
Aa > Ad
【2】a == d
Aa < Ab
Ac > Aa
【3】b == c
Aa < Ab
Ab < Ad
【4】b == d(a一定不等于c)
Aa < Ab
Ac > Ab
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <queue>
#include <algorithm>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <math.h>
#define inf 0x3f3f3f3f
#define ll long long
#define mod 99991
#define N 51000
using namespace std;
int n;
ll a[N];
ll b[N];
ll sum[N];
ll frontLarge[N];
ll frontSmall[N];
ll backLarge[N];
ll backSmall[N];
int lowbit(int k){
return k&(-k);
}
void updata(int k){
while(k <= n){
sum[k]++;
k += lowbit(k);
}
}
ll query(int k){
ll ans = 0;
while(k > 0){
ans += sum[k];
k -= lowbit(k);
}
return ans;
}
int main(){
while(~scanf("%d",&n)){
for(int i = 0; i < n; ++i){
scanf("%lld",&a[i]);
b[i] = a[i];
}
sort(b,b+n);
int len = unique(b,b+n)-b;
memset(sum,0,sizeof(sum));
ll x1=0,x2=0;
for(int i = 0; i < n; ++i){
a[i] = lower_bound(b,b+n,a[i])-b+1;
frontSmall[i] = query(a[i]-1);
frontLarge[i] = query(n)-query(a[i]);
x1 += frontSmall[i];
updata(a[i]);
}
memset(sum,0,sizeof(sum));
for(int i = n-1; i >= 0; --i){
backSmall[i] = query(a[i]-1);
backLarge[i] = query(n)-query(a[i]);
x2 += backSmall[i];
updata(a[i]);
}
ll ans = x1*x2;
for(int i = 0; i < n; ++i){
ans -= frontSmall[i]*backSmall[i];
ans -= backLarge[i]*backSmall[i];
ans -= frontSmall[i]*frontLarge[i];
ans -= backLarge[i]*frontLarge[i];
}
printf("%lld\n",ans);
}
return 0;
}