题目大意
对于 10% 的数据满足 n ≤ 20
对于 30% 的数据满足 n ≤ 2000
另有 20% 的数据满足 a i = b i
对于 100% 的数据满足 n ≤ 100000
解题思路
对于一个合法的三元组,我们发现S里有用的元素至多为3个,分别拥有a,b,c的最大值,甚至只有一个或两个元素。
发现如果只留下有用的元素,S和合法三元组一一对应,那么我们只需要统计最简S的个数即可。
分情况讨论。下面记a,b,c为三个维度。
- |S|=1,有n个,就是1..n
- |S|=2,一对下标(x,y),假如不是某一个a,b,c都比另一个大,例如ax>ay,bx>by,cx>cyax>ay,bx>by,cx>cy就合法。
- |S|=3,直接算有点麻烦,考虑容斥出来,不合法的情况有两种
- 一个的a,b,c全都大于另外两个的abc,记个数为A.
- 一个的其中两维,比如a,b为3个中的最大值,一个的另一维,比如c为最大值。这里又分3种情况。记个数为B.
B其实不太好算,我们考虑计算x拥有其中两维最大值,另一维没有限制的(x,y,z)无序三元组。枚举有限制的两维计算答案,记个数为X,那么X-3*A就是B了。
B是二维偏序统计问题,即对于每个x统计满足ax>ay,bx>byax>ay,bx>by的y的个数cnt,然后贡献就是C2cntCcnt2
A和|S|=2的情况是三维偏序问题,可以排序一维,分治一维,数据结构维护一维。
这样就做完了。
代码
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<map>
#include<set>
using namespace std;
typedef long long ll;
typedef double db;
#define fo(i,j,k) for(i=j;i<=k;i++)
#define fd(i,j,k) for(i=j;i>=k;i--)
#define cmax(a,b) (a=(a>b)?a:b)
#define cmin(a,b) (a=(a<b)?a:b)
const int N=1e5+5,M=5e6+5,mo=1e9+7;
int a[3][N],b[3][N],d[N],id[N],tr[N],n,V,i,j,pp,st,x,tmp,cnt[N];
ll ans,cnt1,cnt2,cnt3,X;
void ins(int x,int v)
{
while (x<=n)
{
tr[x]+=v;
x+=x&(-x);
}
}
int get(int x)
{
int ret=0;
while (x)
{
ret+=tr[x];
x-=x&(-x);
}
return ret;
}
bool cmp(int x,int y)
{
return a[V][x]<a[V][y];
}
void solve(int l,int r)
{
if (l==r) return ;
int m=l+r>>1;
solve(l,m);
solve(m+1,r);
V=1;
sort(id+l,id+1+r,cmp);
fo(i,l,r)
{
x=id[i];
if (x<=m) ins(a[2][x],1);
else cnt[x]+=get(a[2][x]);
}
fo(i,l,r) if (id[i]<=m) ins(a[2][id[i]],-1);
}
int main()
{
freopen("t2.in","r",stdin);
freopen("subset.out","w",stdout);
scanf("%d\n",&n);
fo(j,0,2) fo(i,1,n) scanf("%d",a[j]+i);
pp=1;
fo(i,1,n)
{
id[i]=i;
if(a[0][i]!=a[1][i]) pp=0;
}
V=0;
sort(id+1,id+1+n,cmp);
fo(j,0,2) fo(i,1,n) b[j][i]=a[j][id[i]];
fo(j,0,2) fo(i,1,n) a[j][i]=b[j][i];
fo(i,1,n) id[i]=i;
solve(1,n);
fo(i,1,n)
{
cnt1+=cnt[i];
cnt2+=1ll*cnt[i]*(cnt[i]-1)/2;
}
fo(V,0,2)
{
sort(id+1,id+1+n,cmp);
fo(i,1,n)
{
tmp=get(a[(V+1)%3][id[i]]);
X+=1ll*tmp*(tmp-1)/2;
ins(a[(V+1)%3][id[i]],1);
}
fo(i,1,n) ins(a[(V+1)%3][id[i]],-1);
}
cnt3=X-3*cnt2;
ans=1ll*n*(n-1)*(n-2)/6-cnt3-cnt2;
cnt1=1ll*n*(n-1)/2-cnt1;
ans+=n+cnt1;
printf("%lld",ans);
}