题目
题意:给定两个数组
s
,
t
s,t
s,t,现重排列数组
s
s
s,使得数组
s
s
s小于
t
t
t。问有多少种排列方式。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll mod=998244353;
ll n,m,a[202020],cnt,p[202020],ans,fac[202020],inv[202020],s[202020],invfac[202020];
// 树状数组
ll lowbit(ll aa){
return aa&-aa;
}
void add(ll pos,ll x){
for(ll i=pos;i<=cnt;i+=lowbit(i)) s[i]+=x;
}
ll query(ll x){
ll res=0;
for(int i=x;i>0;i-=lowbit(i)) res+=s[i];
return res;
}
int main(){
cin>>n>>m;
for(int i=1;i<=n;i++){
int x;
cin>>x;
p[x]++;
cnt=max(cnt,(ll)x);
}
for(int i=1;i<=m;i++) cin>>a[i];
fac[0]=fac[1]=inv[0]=inv[1]=invfac[0]=invfac[1]=1;
for(int i=2;i<=max(m,n);i++){
fac[i]=fac[i-1]*i%mod;// fac[i] = i!
inv[i]=(mod-mod/i)*inv[mod%i]%mod;// inv[i] = 1/i
invfac[i]=invfac[i-1]*inv[i]%mod;// invfac[i] = 1/i!
}
// cntt = C(n,p[1]) * C(n-p[1],p[2]) * ... * C(n-p[1]...-p[cnt-1],p[cnt])
// 即数列的总个数 cntt = P(p[1],p[2],...,p[cnt])
int cntt=fac[n];
for(int i=1;i<=cnt;i++) cntt=cntt*invfac[p[i]]%mod,add(i,p[i]);
for(int i=1;i<=min(n,m);i++){
// 贡献:当前数列总个数 * (可取元素个数/剩余元素个数)
ans+=query(a[i]-1)*cntt%mod*inv[n-i+1]%mod,ans%=mod;
// 没有可更新元素,停止迭代计算
if(p[a[i]]==0) break;
// 更新下个迭代的数列总个数
cntt=cntt*inv[n-i+1]%mod*p[a[i]]%mod;
// 更新下个迭代的未使用元素数量
p[a[i]]--;
add(a[i],-1);
// n<m时,如果i能走到n,则额外计算1的贡献
if(n<m&&i==n) ans=(ans+1)%mod;
}
cout<<ans;
return 0;
}