题意
有
n
n
n个小球,每种小球有
c
i
c_i
ci个。将小球进行排列,贡献为
∏
1
l
i
!
\prod \frac{1}{l_i!}
∏li!1,其中
l
i
l_i
li表示一个首尾相接后的连续段(一段连续相同的小球)。
问所有的排列(循环同构算不同方案)。
n
≤
1
e
5
n\leq 1e5
n≤1e5
∑
a
i
≤
2
e
5
\sum a_i \le 2e5
∑ai≤2e5
正解
计算 m m m个颜色相同的球被分成了 n n n段的贡献。显然为 ( e x − 1 ) m [ x n ] (e^x-1)^m[x^n] (ex−1)m[xn]
考虑颜色为 i i i的球的贡献。
设 F k = ( e x − 1 ) k [ x c i ] F_k=(e^x-1)^{k}[x^{c_i}] Fk=(ex−1)k[xci],即颜色为 i i i的球分成 k k k段的贡献。
设颜色为 i i i的球分成了 a i a_i ai段,计算它们的排列数。很明显如果直接 ( ∑ a i ) ! ∏ a i ! \frac{(\sum a_i)!}{\prod a_i!} ∏ai!(∑ai)!是不行的,因为这样可能存在颜色相同的段靠在一起。
这时候肯定要容斥。对于颜色为 i i i的球来说,上面的这个东西其实是算了至多分了 a i a_i ai段的情况。
那么对于一个实际上 k k k段的球来说,它被算到的次数是 ∑ j ≥ k C ( j − 1 , k − 1 ) ∗ K \sum_{j\geq k}C(j-1,k-1)*K ∑j≥kC(j−1,k−1)∗K次(其中 K K K表示的是一个相同的系数)。如果被 j j j算一次的系数为 f j f_j fj,那么我们希望 ∑ j ≥ k C ( j − 1 , k − 1 ) f j = F k \sum_{j\geq k}C(j-1,k-1)f_j=F_k ∑j≥kC(j−1,k−1)fj=Fk,即 ∑ j ≥ k C ( j , k ) f j + 1 = F k + 1 \sum_{j\geq k}C(j,k)f_{j+1}=F_{k+1} ∑j≥kC(j,k)fj+1=Fk+1。二项式反演得 f k + 1 = ∑ j ≥ k C ( j , k ) ( − 1 ) j − k F j + 1 f_{k+1}=\sum_{j\geq k}C(j,k)(-1)^{j-k}F_{j+1} fk+1=∑j≥kC(j,k)(−1)j−kFj+1。
利用 f f f就可以直接算了。于是就变成了计算 ∑ s s ! [ x s ] ∏ i ∑ j f i , j x j j ! \sum_s s![x^s]\prod_i\sum_jf_{i,j}\frac{x^j}{j!} ∑ss![xs]∏i∑jfi,jj!xj(这里 f f f加了一维注意一下)。
于是 i i i的生成函数 A i = ∑ j f i , j x j j ! A_i=\sum_j f_{i,j} \frac{x^j}{j!} Ai=∑jfi,jj!xj
然而题目要让我们计算的是环……所以上面的这个是假的。
算环的时候,考虑枚举第一个是什么颜色,假设颜色是 i i i,并且枚举它的长度 l l l。钦定第一段是长度为 l l l的颜色为 i i i的段,计算方案数之后乘 l l l。为了不算重,要保证最后一段和第二段的颜色不是 i i i。
记 i i i为开头时的生成函数为 B i B_i Bi。类似 F F F的定义,定义 G G G,表示颜色为 i i i的求分成 k + 1 k+1 k+1段,其中 1 1 1段为第一段的贡献。显然有 G k = x e x ( e x − 1 ) k [ x c i ] G_k=xe^x(e^x-1)^k[x^{c_i}] Gk=xex(ex−1)k[xci]
类似地用公式 ( ∑ a i ) ! ∏ a i ! \frac{(\sum a_i)!}{\prod a_i!} ∏ai!(∑ai)!来计算,并且同样要容斥。为了保证最后一段和第二段的颜色不是 i i i,所以还要减去这部分的贡献。
于是我们就得到了生成函数 B i B_i Bi
求答案的时候将生成函数乘起来。具体来说,维护二元组 ( A , B ) (A,B) (A,B),相乘时 ( A 0 , B 0 ) ∗ ( A 1 , B 1 ) = ( A 0 A 1 , A 0 B 1 + A 1 B 0 ) (A_0,B_0)*(A_1,B_1)=(A_0A_1,A_0B_1+A_1B_0) (A0,B0)∗(A1,B1)=(A0A1,A0B1+A1B0)
乘的时候用合并果子的方法合并即可。
using namespace std;
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define N 524288
#define ll long long
#define mo 998244353
#define mo2 998244353ll*998244353ll
ll fac[N],ifac[N];
ll qpow(ll x,ll y=mo-2){
ll r=1;
for (;y;y>>=1,x=x*x%mo)
if (y&1)
r=r*x%mo;
return r;
}
int nN,re[N];
void setlen(int n){
int bit=0;
for (nN=1;nN<=n;nN<<=1,++bit);
re[0]=0;
for (int i=1;i<nN;++i)
re[i]=re[i>>1]>>1|(i&1)<<bit-1;
}
void clear(int A[]){memset(A,0,sizeof(int)*nN);}
int cop(int A[],vector<int> &a){clear(A);for (int i=0;i<a.size();++i) A[i]=a[i];}
void dft(int A[],int flag){
for (int i=0;i<nN;++i)
if (i<re[i])
swap(A[i],A[re[i]]);
static int wnk[N];
for (int i=1;i<nN;i<<=1){
int wn=qpow(3,flag==1?(mo-1)/(2*i):mo-1-(mo-1)/(2*i));
wnk[0]=1;
for (int k=1;k<i;++k)
wnk[k]=(ll)wnk[k-1]*wn%mo;
for (int j=0;j<nN;j+=i<<1)
for (int k=0;k<i;++k){
ll x=A[j+k],y=(ll)A[j+k+i]*wnk[k];
A[j+k]=(x+y)%mo;
A[j+k+i]=(x-y+mo2)%mo;
}
}
if (flag==-1){
int inv=qpow(nN);
for (int i=0;i<nN;++i)
A[i]=(ll)A[i]*inv%mo;
}
}
struct Pair{vector<int> a,b;};
void multi(vector<int> &c,vector<int> &a,vector<int> &b){
static int A[N],B[N],C[N];
int n=a.size()-1+b.size()-1;
setlen(n);
cop(A,a),cop(B,b);
dft(A,1),dft(B,1);
for (int i=0;i<nN;++i) C[i]=(ll)A[i]*B[i]%mo;
dft(C,-1);
c.clear();
for (int i=0;i<=n;++i)
c.push_back(C[i]);
}
void multi(Pair &z,Pair &x,Pair &y){
static int Ax[N],Bx[N],Ay[N],By[N],Az[N],Bz[N];
int n=x.a.size()-1+y.a.size()-1;
setlen(n);
cop(Ax,x.a),cop(Bx,x.b),cop(Ay,y.a),cop(By,y.b);
dft(Ax,1),dft(Bx,1),dft(Ay,1),dft(By,1);
for (int i=0;i<nN;++i){
Az[i]=(ll)Ax[i]*Ay[i]%mo;
Bz[i]=((ll)Ax[i]*By[i]+(ll)Ay[i]*Bx[i])%mo;
}
dft(Az,-1),dft(Bz,-1);
z.a.clear(),z.b.clear();
for (int i=0;i<=n;++i){
z.a.push_back(Az[i]);
z.b.push_back(Bz[i]);
}
}
int n,m;
int c[N];
vector<int> F[N],G[N];
Pair fg[N];
void getF(int ci,vector<int> &F){
static vector<int> a,b;
int n=ci;
a.clear(),b.clear();
for (int i=0,o=1;i<=n;++i,o=mo-o){
a.push_back((ll)qpow(i,ci)%mo*ifac[i]%mo);
b.push_back(o*ifac[i]%mo);
}
multi(F,a,b);
for (int i=F.size()-1-n;i;--i)
F.pop_back();
for (int i=0;i<=n;++i)
F[i]=(ll)fac[i]%mo*ifac[ci]%mo*F[i]%mo;
}
void getG(int ci,vector<int> &G){
static vector<int> a,b;
int n=ci-1;
a.clear(),b.clear();
for (int i=0,o=1;i<=n;++i,o=mo-o){
a.push_back((ll)qpow(i+1,ci-1)%mo*ifac[i]%mo);
b.push_back(o*ifac[i]%mo);
}
multi(G,a,b);
for (int i=G.size()-1-n;i;--i)
G.pop_back();
for (int i=0;i<=n;++i)
G[i]=(ll)fac[i]%mo*ifac[ci-1]%mo*G[i]%mo;
}
void getrev(vector<int> &f,vector<int> &F){
static vector<int> a,b,c;
int n=F.size()-1;
a.clear(),b.clear();
a.push_back(0);
for (int i=1;i<=n;++i)
a.push_back(fac[i-1]*F[i]%mo);
for (int i=0,o=1;i<n;++i,o=mo-o)
b.push_back(o*ifac[i]%mo);
b.push_back(0);
reverse(b.begin(),b.end());
multi(c,a,b);
f.clear();
f.push_back(0);
for (int i=1;i<=n;++i)
f.push_back(c[i+n]*ifac[i-1]%mo);
}
void adjust(vector<int> &g){
int n=g.size()-1;
for (int i=0;i<=n;++i)
g[i]=((ll)g[i]-(i+1<=n?g[i+1]*2:0)+(i+2<=n?g[i+2]:0)+mo+mo)%mo;
}
Pair *h[N];
int nh;
bool cmph(Pair *son,Pair *fa){return son->a.size()>fa->a.size();}
int main(){
// freopen("in.txt","r",stdin);
freopen("always.in","r",stdin);
freopen("always.out","w",stdout);
scanf("%d",&n);
for (int i=1;i<=n;++i)
scanf("%d",&c[i]),m+=c[i];
fac[0]=1;
for (int i=1;i<=m;++i)
fac[i]=fac[i-1]*i%mo;
ifac[m]=qpow(fac[m]);
for (int i=m-1;i>=0;--i)
ifac[i]=ifac[i+1]*(i+1)%mo;
sort(c+1,c+n+1);
for (int i=1;i<=n;++i){
getF(c[i],F[i]);
getrev(fg[i].a,F[i]);
for (int j=0;j<fg[i].a.size();++j)
fg[i].a[j]=fg[i].a[j]*ifac[j]%mo;
getG(c[i],G[i]);
getrev(fg[i].b,G[i]);
fg[i].b[0]=G[i][0];
adjust(fg[i].b);
for (int j=0;j<fg[i].b.size();++j)
fg[i].b[j]=fg[i].b[j]*ifac[j]%mo;
h[nh++]=&fg[i];
}
make_heap(h,h+nh,cmph);
while (nh>1){
Pair *x,*y;
x=h[0],pop_heap(h,h+nh--,cmph);
y=h[0],pop_heap(h,h+nh--,cmph);
multi(*x,*x,*y);
h[nh++]=x;
y->a.clear(),y->b.clear();
push_heap(h,h+nh,cmph);
}
ll ans=0;
for (int i=0;i<h[0]->b.size();++i)
(ans+=fac[i]*h[0]->b[i])%=mo;
printf("%lld\n",ans);
return 0;
}