大致题意
给n个二维数对(ai,bi),求将n个数对排列之后,ai,bi都不是单调不减的。这样的排列有多少个。
大致思路
考虑反过来求然后容斥一下,答案=总的排列数-(ai单调不减或者bi单调不减的排列数)+(ai,bi都单调不减的排列数);只考虑ai单调不减,只要看每一个ai对应了多少个bi,res1*=(ai对应的bi的个数)!。bi单调的时候同样的求法。最后多剪掉的那部分是,ai bi都是单调不减的时候,同一个ai,对应的bi可能有多个也相同,然后他们之间可以相互交换。具体见代码中的res3。
代码
#include<bits/stdc++.h>
using namespace std;
#define maxn 300005
#define maxm 1000006
#define ll long long int
#define INF 0x3f3f3f3f
#define inc(i,l,r) for(int i=l;i<=r;i++)
#define dec(i,r,l) for(int i=r;i>=l;i--)
#define mem(a) memset(a,0,sizeof(a))
#define sqr(x) (x*x)
#define inf (ll)2e18+1
#define mod 998244353
int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f*x;
}
int n;
/*struct node{int x,y;}a[maxn];
bool cmp1(node p,node q){
if(p.x!=q.x)return p.x < q.x;
else return p.y < q.y;
}*/
vector<int>a[maxn],b[maxn];
ll f[maxn];
void init(){
f[0]=1;
for(ll i=1;i<maxn;i++)f[i]=f[i-1]*i%mod;
//fac[maxn-1]=fastpow(maxn-1,mod-2);
//for(ll i=maxn-2;i>=0;i--)fac[i]=fac[i+1]*(i+1)%mod;
}
int main()
{
init();
n=read();
int x,y;
inc(i,1,n){x=read();y=read();a[x].push_back(y);b[y].push_back(x);}
int last=0;
ll res1=1,res2=1,res3=1;
inc(i,1,n){
if(a[i].size()){
int siz=a[i].size();
//printf("siz = %d\n",siz);
res1=res1*f[siz]%mod;
//printf("res1 = %lld\n",res1);
if(res3==-1)continue;
sort(a[i].begin(),a[i].end());
if(last<=a[i].front()){
int pre=0,cnt=0;
inc(j,0,a[i].size()-1){
if(a[i][j]!=pre){
res3=res3*f[cnt]%mod;
pre=a[i][j];
cnt=1;
}
else {
cnt++;
}
}
if(cnt>1)res3=res3*f[cnt]%mod;
last=a[i].back();
}
else {
res3=-1;
}
}
}
inc(i,1,n){
if(b[i].size()){
int siz=b[i].size();
res2=res2*f[siz]%mod;
}
}
//printf("%lld %lld %lld\n",res1,res2,res3);
ll ans=(res1+res2)%mod;
if(res3>0)ans=(ans-res3+mod)%mod;
ans=(f[n]-ans+mod)%mod;
printf("%lld\n",ans);
return 0;
}
/*
3
1 3
2 3
2 3
*/