首先有个显然的结论,对于一个集合
S
S
,当均不为最大值时从
S
S
中删去,那么
|S|≤3
|
S
|
≤
3
且一个
S
S
唯一对应一个答案。
我们先用的三维偏序预处理出对于每个
x
x
,有多少个满足
ax>ay,bx>by,cx>cy
a
x
>
a
y
,
b
x
>
b
y
,
c
x
>
c
y
,记作
dx
d
x
。
类似的,用
O(nlogn)
O
(
n
log
n
)
的二维偏序预处理出有两维大于,另一维不限的个数,记作
dabx,dbcx,dacx
d
a
b
x
,
d
b
c
x
,
d
a
c
x
。
于是讨论
|S|
|
S
|
:
1.
|S|=1
|
S
|
=
1
,显然所有单独的列都为一种合法的
S
S
,贡献。
2.
|S|=2
|
S
|
=
2
,那么要从所有二元集合里去除某一列比另一列三维都大的情况,剩下的均为合法
S
S
,那么就是。
3.
|S|=3
|
S
|
=
3
,先考虑某一列三维均为最大的情况,记作
A
A
;再考虑某一列至少有两个最大值的情况,记作,那么
|A|=∑ni=1(di2),|B|=∑ni=1((dabi2)+(dbci2)+(daci2))
|
A
|
=
∑
i
=
1
n
(
d
i
2
)
,
|
B
|
=
∑
i
=
1
n
(
(
d
a
b
i
2
)
+
(
d
b
c
i
2
)
+
(
d
a
c
i
2
)
)
,但是
A
A
会在中算重三次,所以最后贡献就是
(n3)−|B|+2|A|
(
n
3
)
−
|
B
|
+
2
|
A
|
。
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 100010
#define ll long long
using namespace std;
int n,c[N];
ll ans;
int read()
{
int x=0,f=1;char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=-1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x*f;
}
struct node
{
int x,y,z,d,dx,dy,dz;
}a[N],b[N];
void add(int x,int d){for(;x<=n;x+=(x&-x)) c[x]+=d;}
int qry(int x){int r=0;for(;x;x-=(x&-x)) r+=c[x];return r;}
bool cmpx(node p,node q)
{
return p.x<q.x;
}
bool cmpy(node p,node q)
{
return p.y<q.y;
}
ll G(int x)
{
return (ll)x*(x-1)/2;
}
void solve(int L,int R,int l,int r)
{
if(L>=R||l>=r) return ;
int mid=(l+r>>1);
for(int i=L;i<=R;i++)
if(a[i].y<=mid) add(a[i].z,1);
else a[i].d+=qry(a[i].z);
int top=L-1,tmp;
for(int i=L;i<=R;i++)
if(a[i].y<=mid) b[++top]=a[i],add(a[i].z,-1);
tmp=top;
for(int i=L;i<=R;i++)
if(a[i].y>mid) b[++top]=a[i];
for(int i=L;i<=R;i++)
a[i]=b[i];
solve(L,tmp,l,mid);
solve(tmp+1,R,mid+1,r);
}
int main()
{
n=read();
for(int i=1;i<=n;i++)
a[i].x=read();
for(int i=1;i<=n;i++)
a[i].y=read();
for(int i=1;i<=n;i++)
a[i].z=read();
sort(a+1,a+n+1,cmpx);
solve(1,n,1,n);
sort(a+1,a+n+1,cmpx);
memset(c,0,sizeof(c));
for(int i=1;i<=n;i++)
a[i].dz+=qry(a[i].y),add(a[i].y,1);
memset(c,0,sizeof(c));
for(int i=1;i<=n;i++)
a[i].dy+=qry(a[i].z),add(a[i].z,1);
memset(c,0,sizeof(c));
sort(a+1,a+n+1,cmpy);
for(int i=1;i<=n;i++)
a[i].dx+=qry(a[i].z),add(a[i].z,1);
ans=G(n)+n;
for(int i=1;i<=n;i++)
ans-=a[i].d;
ans+=G(n)*(n-2)/3;
for(int i=1;i<=n;i++)
{
ll A=G(a[i].d),B=G(a[i].dx)+G(a[i].dy)+G(a[i].dz);
ans-=(B-2*A);
}
printf("%lld",ans);
return 0;
}