三维偏序问题请看下面
Description
Input
第一行一个正整数 n
第二行 n 个数字,表示排列 a i
第三行 n 个数字,表示排列 b i
第四行 n 个数字,表示排列 c i
Output
一行一个整数,表示答案
Sample Input
8
1 7 5 3 4 8 2 6
3 1 2 7 4 8 5 6
6 3 4 5 8 2 1 7
Sample Output
42
Data Constraint
对于 10% 的数据满足 n ≤ 20
对于 30% 的数据满足 n ≤ 2000
另有 20% 的数据满足 a i = b i
对于 100% 的数据满足 n ≤ 100000
Solution
考虑S只包括了包含了三元组的位置,那么|S|<=3
若|S|=1,则答案为n
若|S|=2,则答案为
C2n
减去不合法的
在这种情况下,不合法的只有当某一列三个数都小于另一列,即这一列是没用的。
这个就是三维偏序问题了,具体怎么做,下面再说。
|S|=3,则答案为
C3n
减去不合法的
不合法的有三个最大值都集中在一列或两列
集中在一列,和上面的一样,三维偏序问题
集中在两列,枚举是哪两个数在同一列,然后变成二维偏序问题。
三维偏序问题
也可以理解为三维数点问题,就以这题中数一个点三维都小于另一个点为例。
二维偏序问题是很简单的:
一维排序,然后第二位用树状数组维护一下。
三维偏序也可以这么做,把树状数组变成树状数组套线段树,也不麻烦。
也可以考虑第二维分治,第三维用树状数组。
具体来说,先按照一维排序,在分治时分成两部分,只考虑左半边对右半边的贡献。
这时,可以左半边和右半边分别按第二维排序,因为保证了第一位左边比右边小。
然后只剩一维了,用树状数组维护一下。
递归做到底层就行了。
Code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define N 101000
#define ll long long
#define lowbit(x) (x&(-x))
using namespace std;
int n,t[N*2];
ll ans,A[N],B=0,X=0;
struct node{
int a,b,c,z;
}a[N];
bool cnt1(node a,node b){return a.a<b.a;}
bool cnt2(node a,node b){return a.b<b.b;}
void ins(int x,int y)
{
for(;x<=n;x+=lowbit(x)) t[x]+=y;
}
int get(int x)
{
int ans=0;
for(;x;x-=lowbit(x)) ans+=t[x];
return ans;
}
void divide(int l,int r)
{
if(l==r) return;
int m=(l+r)/2;
divide(l,m);divide(m+1,r);
sort(a+l,a+r+1,cnt1);
sort(a+l,a+m+1,cnt2);
sort(a+m+1,a+r+1,cnt2);
int j=m+1;
fo(i,l,m)
{
ins(a[i].c,1);
while(j<=r&&a[j].b<a[i].b) j++;
while(j<=r&&a[j].b<a[i+1].b)
{
ll c=get(a[j].c);
ans-=c;
A[a[j].z]+=c;
j++;
}
}
while(j<=r)
{
ll c=get(a[j].c);
ans-=c;
A[a[j].z]+=c;
j++;
}
fo(i,l,m) ins(a[i].c,-1);
}
void calc()
{
memset(t,0,sizeof(t));
sort(a+1,a+n+1,cnt1);
fo(i,1,n)
{
A[a[i].z]+=get(a[i].b);
ins(a[i].b,1);
}
}
int main()
{
freopen("subset.in","r",stdin);
freopen("subset.out","w",stdout);
scanf("%d",&n);
fo(i,1,n) scanf("%d",&a[i].a);
fo(i,1,n) scanf("%d",&a[i].b);
fo(i,1,n) scanf("%d",&a[i].c),a[i].z=i;
ans=n;
ans=ans+ans*(ans-1)/2+ans*(ans-1)*(ans-2)/3/2;
sort(a+1,a+n+1,cnt1);
divide(1,n);
fo(i,1,n) X+=A[i]*(A[i]-1)/2;
ans-=X;
memset(A,0,sizeof(A));
calc();
fo(i,1,n) B+=A[i]*(A[i]-1)/2;
fo(i,1,n) swap(a[i].b,a[i].c),A[i]=0;
calc();
fo(i,1,n) B+=A[i]*(A[i]-1)/2;
fo(i,1,n) swap(a[i].a,a[i].c),A[i]=0;
calc();
fo(i,1,n) B+=A[i]*(A[i]-1)/2;
ans=ans-B+3*X;
printf("%lld\n",ans);
}