题目:
http://acm.hdu.edu.cn/showproblem.php?pid=5792
题意:
找到数列中所有满足:
a≠b≠c≠d,1≤a<b≤n,1≤c<d≤n,Aa<Ab,Ac>Ad
.
的四元组个数。
思路:
明显只要线段树处理出每个数字前后比它大比它小的数字个数,类似求逆序数,然后直接计算即可把所有的升序对和逆序对乘积,然后减去每个数字重复出现的对数,就是减去它左边比他小的个数乘以右边比它小的个数,再减去左边比它大的个数乘以右边比它大的个数,再减去单边比它小的个数和比它大的个数的乘积,一共四种重复情况。有点卡常数。
代码:
//kopyh
#include <bits/stdc++.h>
#define N 50010
using namespace std;
int n,m,res,flag;
#define root 1 , n , 1
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
struct node
{
int pos,val;
node(int x=0,int y=0){pos=x,val=y;}
friend bool operator < (node a, node b)
{
return a.val < b.val;
}
friend node operator + (node a, node b)
{
return node(a.pos,a.val+b.val);
}
}arr[N<<2];
int add[N<<2],tot;
void pushUp(int rt)
{
arr[rt] = arr[rt<<1]+arr[rt<<1|1];
}
void pushDown(int l,int r,int rt)
{
if(add[rt])
{
int m = (l+r)>>1;
add[rt<<1] += add[rt];
add[rt<<1|1] += add[rt];
arr[rt<<1].val += add[rt];
arr[rt<<1|1].val += add[rt];
add[rt] = 0;
}
}
void updata(int l,int r,int rt,int ql,int qr,int val)
{
if(l>qr||ql>r)return;
if(l>=ql&&r<=qr)
{
arr[rt].val += val;
add[rt] += val;
return;
}
pushDown(l,r,rt);
int m = (l+r)>>1;
if(ql<=m)updata(lson,ql,qr,val);
if(qr>m)updata(rson,ql,qr,val);
pushUp(rt);
}
void build(int l,int r,int rt)
{
add[rt]=0;
if(l == r)
{
arr[rt].val = 0;
return;
}
int m = (l+r)>>1;
build(lson);
build(rson);
pushUp(rt);
}
node query(int l,int r,int rt,int ql,int qr)
{
if(l>qr||ql>r)
return node(0,0);
if(l>=ql&&r<=qr)
return arr[rt];
pushDown(l,r,rt);
int m = (l+r)>>1;
return query(lson,ql,qr)+query(rson,ql,qr);
}
int ld[N],lx[N],rd[N],rx[N];
int a[N],b[N],sum[N],num[N];
map<int,int>mp;
int main()
{
int i,j,k,cas,T,t;
while(scanf("%d",&m)!=EOF)
{
mp.clear();
for(i=0;i<m;i++)
{
scanf("%d",&a[i]);
b[i]=a[i];
}
sort(b,b+m);
sum[0]=num[0]=0;
for(i=0,j=0;i<m;i++)
{
if(!mp[b[i]])j++,mp[b[i]]=j,num[j]=1,sum[j]=sum[j-1]+num[j-1];
else num[j]++;
}
n=j;
build(root);
for(i=0;i<m;i++)
{
t=mp[a[i]];
updata(root,t,t,1);
lx[i]=query(root,1,t-1).val;
ld[i]=query(root,t+1,n).val;
rx[i]=sum[t]-lx[i];
rd[i]=m-lx[i]-ld[i]-rx[i]-num[t];
}
long long res=0,x=0,y=0;
for(i=0;i<m;i++)
x+=lx[i],y+=ld[i];
res = x*y;
for(i=0;i<m;i++)
res-=lx[i]*rx[i], res-=ld[i]*rd[i], res-=lx[i]*ld[i], res-=rd[i]*rx[i];
printf("%I64d\n",res);
}
return 0;
}