题目链接:
World is Exploding
Time Limit: 2000/1000 MS (Java/Others)
Memory Limit: 65536/65536 K (Java/Others)
Problem Description
Given a sequence A of length n, count how many quadruples (a,b,c,d) satisfy:
a, b, c, d are pairwise distinct (a≠b≠c≠d), 1≤a<b≤n, 1≤c<d≤n, Aa<Ab, Ac>Ad.
Input
The input consists of multiple test cases.
Each test case begins with an integer n on a single line.
The next line contains n integers A1,A2⋯An.
1≤n≤50000
0≤Ai≤1e9
Each test case begins with an integer n on a single line.
The next line contains n integers A1,A2⋯An.
1≤n≤50000
0≤Ai≤1e9
Output
For each test case, output a line containing a single integer.
Sample Input
4
2 4 1 3
4
1 2 3 4
Sample Output
1
0
题意:
问符合题目给的四元组有多少个;
思路:
容斥:先算出满足 a<b 且 Aa<Ab 的数对个数,以及满足 c<d 且 Ac>Ad 的数对个数,两者相乘得到不考虑下标互异时的总数;再减去 a==c, a==d, b==c, b==d 这四种情况的个数,就是答案了。因为 a<b, c<d 且 Aa<Ab, Ac>Ad,这四个等式中至多只有一个能成立,所以直接相减不会重复;然后就是用树状数组
求pres[i],preb[i],nexs[i],nexb[i],分别表示第i个数前边比它小、前边比它大、后面比它小、后面比它大的个数,具体的看代码吧;
AC代码:
/************************************************
┆ ┏┓ ┏┓ ┆
┆┏┛┻━━━┛┻┓ ┆
┆┃ ┃ ┆
┆┃ ━ ┃ ┆
┆┃ ┳┛ ┗┳ ┃ ┆
┆┃ ┃ ┆
┆┃ ┻ ┃ ┆
┆┗━┓ ┏━┛ ┆
┆ ┃ ┃ ┆
┆ ┃ ┗━━━┓ ┆
┆ ┃ AC代马 ┣┓┆
┆ ┃ ┏┛┆
┆ ┗┓┓┏━┳┓┏┛ ┆
┆ ┃┫┫ ┃┫┫ ┆
┆ ┗┻┛ ┗┻┛ ┆
************************************************ */
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <bits/stdc++.h>
#include <stack>
using namespace std;
// For(i,j,n): inclusive counting loop header. Declares `int i` and runs it
// from j up to and including n. The bound n is re-evaluated on every iteration.
#define For(i,j,n) for(int i=j;i<=n;i++)
// mst(ss,b): memset the first sizeof(ss) bytes of ss to the byte value b.
// The expansion already ends in ';', so writing `mst(x,0);` yields the memset
// statement followed by an empty statement.
#define mst(ss,b) memset(ss,b,sizeof(ss));
// Shorthand for long long.
typedef long long LL;
// Reads the next decimal integer from stdin into `num`.
//
// Every character before the first digit is skipped. If the character
// immediately preceding that digit is '-', the parsed value is negated.
// The single character that follows the last digit is consumed as well.
//
// If the input ends before any digit is found, `num` is set to 0 and the
// function returns (the character is kept as an int so EOF can be told apart
// from real data; comparing a char against the digit range alone would never
// leave the skip loop once getchar() starts returning EOF).
template<class T> void read(T&num) {
    int ch = getchar();
    bool negative = false;
    // Skip everything up to the first digit, remembering whether the most
    // recently skipped character was a minus sign.
    while (ch != EOF && (ch < '0' || ch > '9')) {
        negative = (ch == '-');
        ch = getchar();
    }
    num = 0;
    if (ch == EOF) return;
    // Accumulate the run of digits.
    for (; ch >= '0' && ch <= '9'; ch = getchar()) {
        num = num * 10 + (ch - '0');
    }
    if (negative) num = -num;
}
// Scratch digit stack shared by print(): stk[1..tp] holds the digits of the
// value being printed, least significant first. tp is 0 between calls.
int stk[70], tp;
// Writes the integer `p` in decimal to stdout, followed by a newline.
//
// Negative values are printed with a leading '-'. Digits are taken as the
// magnitude of p%10 so that the most negative value of T is handled without
// ever negating p itself.
template<class T> inline void print(T p) {
    if(!p) { puts("0"); return; }
    const bool negative = p < 0;
    while(p) {
        const int digit = static_cast<int>(p % 10);  // sign follows p for negative p
        stk[++tp] = digit < 0 ? -digit : digit;
        p /= 10;
    }
    if(negative) putchar('-');
    // Pop the digits back out, most significant first.
    while(tp) putchar(stk[tp--] + '0');
    putchar('\n');
}
// Modulus constant; not referenced in the visible part of this file.
const LL mod=1e9+7;
// Pi computed as acos(-1); not referenced in the visible part of this file.
const double PI=acos(-1.0);
// Large integer constant; not referenced in the visible part of this file.
const int inf=1e9;
// Capacity of the per-index arrays below and of po[]. main() uses indices
// 1..n, so n has to stay strictly below N.
const int N=5e4+10;
// Not referenced in the visible part of this file.
const int maxn=1e3+14;
// Floating-point tolerance; not referenced in the visible part of this file.
const double eps=1e-8;
// n     : length of the current test case's sequence (read in main).
// fa[i] : for sorted position i, the sorted position of the first element
//         in the run of equal values that contains i.
// pres[k] / preb[k] : number of elements before original position k whose
//         value is strictly smaller / strictly bigger than the k-th value.
// nexs[k] / nexb[k] : the same counts for elements after original position k.
// sum[] : binary indexed tree storage used by update() and query().
int n,fa[N],pres[N],preb[N],nexs[N],nexb[N],sum[N];
// One element of the input: its value paired with where it stood in the input.
struct node
{
    int a;   // element value
    int id;  // 1-based position in the original sequence
};
// Working copy of the sequence that main() sorts; entries 1..n carry data.
node po[N];
// Sort predicate: ascending by value; equal values are ordered by ascending
// original position. Returns non-zero when x must come before y.
int cmp(node x,node y)
{
    return x.a!=y.a ? x.a<y.a : x.id<y.id;
}
// Sort predicate: descending by value; equal values are ordered by ascending
// original position. Returns non-zero when x must come before y.
int cmp1(node x,node y)
{
    return x.a!=y.a ? x.a>y.a : x.id<y.id;
}
// Returns the value of the lowest set bit of x (0 when x is 0).
int lowbit(int x)
{
    const int negated = -x;
    return x & negated;
}
// Binary indexed tree point update: adds 1 at position x by incrementing
// every node of sum[] on the path x, x+lowbit(x), ... while it stays <= n.
inline void update(int x)
{
    for(int pos=x;pos<=n;pos+=lowbit(pos))
        ++sum[pos];
}
// Binary indexed tree prefix query: returns the total of the nodes of sum[]
// on the path x, x-lowbit(x), ... down to (but excluding) 0, i.e. how many
// update() calls so far targeted a position in 1..x.
int query(int x)
{
    int total=0;
    for(int pos=x;pos;pos-=lowbit(pos))
        total+=sum[pos];
    return total;
}
int main()
{
while(scanf("%d",&n)!=EOF)
{
For(i,1,n)read(po[i].a),po[i].id=i;
sort(po+1,po+n+1,cmp);
po[0].a=-1;
mst(sum,0);
For(i,1,n)
{
if(po[i].a==po[i-1].a)fa[i]=fa[i-1];
else fa[i]=i;
pres[po[i].id]=query(po[i].id)-(i-fa[i]);
nexs[po[i].id]=fa[i]-1-pres[po[i].id];
update(po[i].id);
}
mst(sum,0);
sort(po+1,po+n+1,cmp1);
For(i,1,n)
{
if(po[i].a==po[i-1].a)fa[i]=fa[i-1];
else fa[i]=i;
preb[po[i].id]=query(po[i].id)-(i-fa[i]);
nexb[po[i].id]=fa[i]-1-preb[po[i].id];
update(po[i].id);
}
sort(po+1,po+n+1,cmp1);
LL ans1=0,ans2=0,ans;
For(i,1,n)
{
ans1=ans1+pres[i];
ans2=ans2+preb[i];
}
ans=ans1*ans2;
For(i,1,n)
{
ans=ans-nexs[i]*nexb[i];//a==c
ans=ans-preb[i]*nexb[i];//a==d
ans=ans-pres[i]*nexs[i];//b==c
ans=ans-pres[i]*preb[i];//b==d
}
cout<<ans<<endl;
}
return 0;
}