这是我见过的为数不多的良心九怜题之一。
题目大意
给定一个长度为$n$序列,你要在序列末尾加入$m$个$[L,R]$之间的数$m\leq 10^7,L,R\leq 10^9$,使得该序列猴子排序的轮数(一轮是指随机打乱整个序列,不断重复操作直到否有序)期望最大,求这个最大的期望。
题解
假设序列元素互不相同,那么有序的排列方式只有一个,而排列方式的数量有$n!$种,每轮成功的概率是$\frac {1}{n!}$,所以期望轮数是$n!$。
考虑第$i$种元素有$cnt_i$个的序列有多少个。
元素是互不相同的,所以$cnt_i$个第$i$个元素在有序序列中的相对位置是固定的,而元素在这些位置的排列是任意的,所有有序的排列方式有$\prod cnt_i!$个,期望是$\frac {n!}{\prod cnt_i!}$。
若使上式子最大,由于$(n+1)!(n-1)!>(n!)^2$很明显希望使得$\max cnt_i$最小且平均。
那么我们统计原序列中$[L,R]$之间的数,每次加入数量最小的元素,直到加入了$m$个数,最后计算答案即可,这个排一下序优化一下就行了。
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define LL long long
#define M 200020
#define MAXN 10200010
#define mod 998244353
using namespace std;
namespace IO{
const int BS=(1<<20); int Top=0;
char Buffer[BS],OT[BS],*OS=OT,*HD,*TL,SS[20]; const char *fin=OT+BS-1;
char Getchar(){if(HD==TL){TL=(HD=Buffer)+fread(Buffer,1,BS,stdin);} return (HD==TL)?EOF:*HD++;}
void flush(){fwrite(OT,1,OS-OT,stdout);}
void Putchar(char c){*OS++ =c;if(OS==fin)flush(),OS=OT;}
void write(int x){
if(!x){Putchar('0');return;} if(x<0) x=-x,Putchar('-');
while(x) SS[++Top]=x%10,x/=10;
while(Top) Putchar(SS[Top]+'0'),--Top;
}
int read(){
int nm=0,fh=1; char cw=Getchar();
for(;!isdigit(cw);cw=Getchar()) if(cw=='-') fh=-fh;
for(;isdigit(cw);cw=Getchar()) nm=nm*10+(cw-'0');
return nm*fh;
}
}
using namespace IO;
int add(int x,int y){return (x+y>=mod)?x+y-mod:x+y;}
int mul(int x,int y){return (x==1||y==1)?x+y-1:(LL)x*(LL)y%mod;}
int qpow(int x,int sq){
int res=1;
for(;sq;sq>>=1,x=mul(x,x)) if(sq&1) res=mul(res,x);
return res;
}
int p[M],fac[MAXN],ifac[MAXN],cnt[M],t[M];
int main(){
fac[0]=1;
for(int i=1;i<MAXN;++i) fac[i]=mul(fac[i-1],i); ifac[MAXN-1]=qpow(fac[MAXN-1],mod-2);
for(int i=MAXN-1;i;--i) ifac[i-1]=mul(ifac[i],i);
for(int T=read();T;--T){
int n=read(),m=read(),l=read(),r=read();
int now=r-l+1,num=0,tot=1,ans,tmp=0,fin;
for(int i=1;i<=n;i++) cnt[i]=0,p[i]=read();
sort(p+1,p+n+1),ans=fac[n+m];
for(int i=1;i<=n;i++,tot++){
while(p[i]==p[i+1]&&i<n) i++,cnt[tot]++; cnt[tot]++;
if(p[i]>=l&&p[i]<=r) t[++tmp]=cnt[tot],now--;
else ans=mul(ans,ifac[cnt[tot]]);
} tot--,sort(t+1,t+tmp+1);
for(fin=1;fin<=tmp;fin++){
if((LL)(t[fin]-num)*(LL)now>(LL)m) break;
m-=(t[fin]-num)*now,now++,num=t[fin];
} num+=m/now,m%=now;
ans=mul(ans,mul(qpow(ifac[num+1],m),qpow(ifac[num],now-m)));
for(int i=fin;i<=tmp;i++) ans=mul(ans,ifac[t[i]]);
write(ans),Putchar('\n');
}flush(); return 0;
}