Description
给定nnn个正整数aia_iai,令N+1=∑aiN+1=\sum a_iN+1=∑ai
将执行NNN次操作,每次等概率随机选择一个非零的aia_iai并令其减一,显然NNN次操作结束之后有且仅有一个ai=1a_i=1ai=1
对于一开始的nnn个aia_iai,分别求出它们最后为111的概率
n,ai≤30n,a_i \leq 30n,ai≤30
Analysis
写这题时作死没开longlong,结果还是被乘法忘记转longlong坑死了
以后这种题无脑上longlong,千万别玩火
每次选择非零aia_iai减一,每次操作概率会更改,不好统计答案。
考虑将操作转化成,随机选择aia_iai,若ai>0a_i>0ai>0则令其减一,否则继续随机,显然答案等价。这样每次操作的每个aia_iai被选中的概率都是1/n1/n1/n
使用EGF来描述操作序列,为了方便计算,我们把最后一个剩下的数也加入操作序列。不妨假设最后剩下的是a1a_1a1
对于2≤i≤n2\leq i\leq n2≤i≤n,iii至少出现aia_iai次;对于111,111恰好出现a1a_1a1次,且一定在序列末出现
令Gk(x)=∑i=0ak−1xii!G_k(x)=\sum_{i=0}^{a_k-1}\frac{x^i}{i!}Gk(x)=∑i=0ak−1i!xi,那么合法序列(除最后一项)的EGF就是
F(x)=xa1−1(a1−1)!∏i≠1(ex−Gi(x))F(x)=\frac{x^{a_1-1}}{(a_1-1)!}\prod_{i\neq 1}(e^x-G_i(x))F(x)=(a1−1)!xa1−1i̸=1∏(ex−Gi(x))
答案即为
∑i≥0i![xi]F(x)ni+1\sum_{i\geq 0}\frac{i![x^i]F(x)}{n^{i+1}}i≥0∑ni+1i![xi]F(x)
将F(x)F(x)F(x)展开,每部分形如λxdejx\lambda x^d e^{jx}λxdejx,其中λ\lambdaλ可以dpdpdp求出,考虑计算该部分对答案的贡献
λnd+1∑i≥0ji(d+i)!i!ni\frac{\lambda}{n^{d+1}} \sum_{i\geq 0}\frac{j^i (d+i)!}{i!n^i}nd+1λi≥0∑i!niji(d+i)!
=λd!nd+1∑i≥0(jn)i(d+ii)=\frac{\lambda d!}{n^{d+1}} \sum_{i\geq 0}(\frac{j}{n})^i {d+i \choose i}=nd+1λd!i≥0∑(nj)i(id+i)
=λd!nd+1(nn−j)d+1=\frac{\lambda d!}{n^{d+1}} (\frac{n}{n-j})^{d+1}=nd+1λd!(n−jn)d+1
最后还剩下一个问题,如果对于每个aia_iai我们都暴力做dp来求λ\lambdaλ,令m=max{ai}m=max\{a_i\}m=max{ai}复杂度就是O(n4m2)O(n^4 m^2)O(n4m2),会TLE
正确姿势是一开始求出∏i(ex−Gi(x))\prod_{i}(e^x-G_i(x))∏i(ex−Gi(x)),对于每个aia_iai再消除它
复杂度O(n3m2)O(n^3m^2)O(n3m2)
Code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<assert.h>
#define fo(i,a,b) for(int i=(a);i<=(b);++i)
#define fd(i,b,a) for(int i=(b);i>=(a);--i)
using namespace std;
typedef long long ll;
int read(){int n=0,p=1;char ch;for(ch=getchar();ch<'0' || ch>'9';ch=getchar())if(ch=='-') p=-1;for(;'0'<=ch && ch<='9';ch=getchar()) n=n*10+ch-'0';return n*p;}
const int N=33,mo=998244353;
void inc(int &x,int y){x=(x+y)%mo;}
int qmi(int x,int n)
{
int t=1;
for(x%=mo;n;n>>=1,x=1ll*x*x%mo) if(n&1) t=1ll*t*x%mo;
return t;
}
int n,m,a[N],f[N][N][N*N],g[N][N*N],fac[N*N],ifac[N*N],inv[N*N];
int main()
{
n=read();
fo(i,1,n) a[i]=read(),m=max(m,a[i]);
fac[0]=ifac[0]=inv[1]=1;
fo(i,1,n*m) fac[i]=1ll*fac[i-1]*i%mo;
ifac[n*m]=qmi(fac[n*m],mo-2);
fd(i,n*m-1,1) ifac[i]=1ll*ifac[i+1]*(i+1)%mo,inv[i+1]=1ll*ifac[i+1]*fac[i]%mo;
f[0][0][0]=1;
fo(i,1,n)
fo(j,0,i)
fo(k,0,i*m)
{
if(j) inc(f[i][j][k],f[i-1][j-1][k]);
fo(l,0,min(a[i]-1,k))
inc(f[i][j][k],-1ll*f[i-1][j][k-l]*ifac[l]%mo);
}
fo(i,1,n)
{
memset(g,0,sizeof(g));
fd(j,n-1,0)
fo(k,0,(n-1)*m)
{
g[j][k]=f[n][j+1][k];
fo(l,0,min(a[i]-1,k))
inc(g[j][k],1ll*g[j+1][k-l]*ifac[l]%mo);
}
int ans=0;
fo(j,0,n-1)
fo(k,0,(n-1)*m) if(g[j][k])
{
int lamda=1ll*ifac[a[i]-1]*g[j][k]%mo,d=(a[i]-1)+k;
int t=1ll*lamda*qmi(inv[n-j],d+1)%mo*fac[d]%mo;
ans=(ans+t)%mo;
}
printf("%d ",(ans%mo+mo)%mo);
}
return 0;
}