题意
给出一个长度为 $n$ 的序列 $A$，问能找出多少个四元组 $(a,b,c,d)$，满足 $a,b,c,d$ 互不相同，$1\le a<b\le n$，$1\le c<d\le n$，$A_a<A_b$，$A_c>A_d$。
思路
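先把输入离散化一下。然后，对于每个数，我们统计它前面比它大的数和比它小的数的个数，记为 pre_g 和 pre_l；同理统计它后面比它大的数和比它小的数的个数，记为 back_g 和 back_l。$\sum_{i=1}^{n} pre\_l[i] \cdot \sum_{i=1}^{n} pre\_g[i]$ 就是不考虑下标互不相同时的所有解，但是有重复（两对数共用了同一个下标，依次对应 $b=d$、$a=c$、$a=d$、$b=c$ 四种情况），容斥一下，要删去
$$\sum_{i=1}^{n}\left(pre\_l[i]\cdot pre\_g[i]+back\_l[i]\cdot back\_g[i]+pre\_g[i]\cdot back\_g[i]+pre\_l[i]\cdot back\_l[i]\right)$$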
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e4+100;
typedef long long ll;
/*-------------------------------------*/
#define max(a,b) (a>b)?a:b
#define min(a,b) (a>b)?b:a
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
int num[maxn];
int sum[maxn << 2];
ll back_g[maxn],pre_g[maxn],back_l[maxn],pre_l[maxn];
// Recomputes the counter of node rt as the total of its two children.
void pushup(int rt){
    const int leftChild = rt << 1;
    const int rightChild = leftChild | 1;
    sum[rt] = sum[leftChild] + sum[rightChild];
}
// Records one more occurrence of position `id` in the counting segment tree.
// Node rt covers the range [l, r]; the counters on the path from that node
// down to the leaf for `id` are refreshed on the way back up.
void update(int id,int l, int r, int rt){
    if (l == r) {
        // Single-position leaf: bump its counter.
        ++sum[rt];
        return;
    }
    const int m = (l + r) >> 1;
    if (id <= m) {
        // Target lies in the left half [l, m].
        update(id, l, m, rt << 1);
    } else {
        // Target lies in the right half [m + 1, r].
        update(id, m + 1, r, rt << 1 | 1);
    }
    pushup(rt);
}
// Returns how many recorded positions fall inside [L, R].
// Node rt covers the range [l, r]; only children overlapping [L, R] are visited.
int query(int L, int R, int l, int r, int rt){
    const bool fullyCovered = (L <= l) && (r <= R);
    if (fullyCovered) {
        return sum[rt];
    }
    const int m = (l + r) >> 1;
    // Left child covers [l, m], right child covers [m + 1, r].
    const int fromLeft = (L <= m) ? query(L, R, l, m, rt << 1) : 0;
    const int fromRight = (m < R) ? query(L, R, m + 1, r, rt << 1 | 1) : 0;
    return fromLeft + fromRight;
}
/*----------------------------------*/
// Element indices 0..n-1; main sorts them by input value (through cmp) to derive ranks.
int ran[maxn];
// Raw input values in reading order, 0-based.
int num_ha[maxn];
// Ordering predicate over element indices: non-zero exactly when the input value
// stored at index x is strictly smaller than the one stored at index y.
int cmp(int x, int y) {
    const int lhs = num_ha[x];
    const int rhs = num_ha[y];
    return (lhs < rhs) ? 1 : 0;
}
int main(){
int n;
while(~scanf("%d", &n)){
for (int i = 0; i < n; i++){
scanf("%d", &num_ha[i]);
ran[i] = i;
}
sort(ran,ran+n,cmp);
num[ran[0]+1] = 1;
for(int i = 1 ; i < n ; i ++){
if(num_ha[ran[i]] == num_ha[ran[i-1]])
num[ran[i]+1] = num[ran[i-1]+1];
else num[ran[i]+1] = i+1;
}
memset(sum, 0, sizeof(sum));
for(int i = 1 ; i <= n ; i ++){
pre_g[i] = query(num[i]+1, n,1,n,1);
update(num[i], 1, n, 1);
}
memset(sum, 0, sizeof(sum));
for(int i = 1 ; i <= n ; i ++){
pre_l[i] = i-query(num[i], n,1,n,1)-1;
update(num[i], 1, n, 1);
}
memset(sum, 0, sizeof(sum));
for(int i = n ; i >= 1 ; i --){
back_g[i] = query(num[i]+1, n,1,n,1);
update(num[i], 1, n, 1);
}
memset(sum, 0, sizeof(sum));
for(int i = n ; i >= 1 ; i --){
back_l[i] = n - i - query(num[i], n,1,n,1);
update(num[i], 1, n, 1);
}
ll cnt1 = 0, cnt2 = 0;
for(int i = 1; i <= n; i++) {
cnt1 += pre_l[i];
cnt2 += pre_g[i];
}
ll ans = cnt1 * cnt2;
for(int i = 1; i <= n; i++) {
ans -= pre_l[i] * pre_g[i];
ans -= back_l[i] * back_g[i];
ans -= pre_g[i] * back_g[i];
ans -= pre_l[i] * back_l[i];
}
printf("%I64d\n", ans);
}
}