Description
给出一个长度为
n
的数组
Data Constraint
2≤n≤5000 0≤ai,k≤109
Solution
首先通过归纳可以得到题目让我们求的是
那也就是让我们求
将选择 [1,n] 中的每个数的概率看成 1 ,最后答案再乘上
接下来我们考虑生成函数 EGF 。
考虑 ai 它的生成函数 Fi(x)
那总的答案的生成函数
F(x)
=
Πni=1Fi(x)
=
enxΠni=1(ai−x)
那答案就是让我们求
[xk]n!nkF(x)
,也就是
[xk]n!nkenxΠni=1(ai−x)
O(n2)
暴力卷积求出
Πni=1(ai−x)
,或者打个分治
NTT O(n log2 n)
也行,记得到的多项式为
G(x)
。
那答案就是求
[xk]n!nkenxG(x)
考虑到
G(x)
只有
n+1
项,于是答案可以写成
随便算算就好,时间复杂度 O(n2) 或 O(n log2 n)
Code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define fo(i,j,l) for(int i=j;i<=l;++i)
#define fd(i,j,l) for(int i=j;i>=l;--i)
using namespace std;
typedef long long ll;
const ll N=34e4,mo=998244353,ZD=262144;
ll a[N],f[N],g[N],aa[N],bb[N],cc[N];
ll jc[N],w[N];
int n,bits[N];
ll k;
inline ll ksm(ll o,ll t)
{
ll y=1;
for(;t;t>>=1,o=o*o%mo)
if(t&1)y=y*o%mo;
return y;
}
inline int read()
{
int o=0; char ch=' ';
for(;ch<'0'||ch>'9';ch=getchar());
for(;ch>='0'&&ch<='9';ch=getchar())o=o*10+ch-48;
return o;
}
inline ll mod(ll o)
{return o<0?o+mo:(o>=mo?o-mo:o);}
void prepare()
{
w[0]=1; w[1]=ksm(3,(mo-1)/ZD);
fo(i,2,ZD)w[i]=w[i-1]*w[1]%mo;
}
inline void dft(ll *c,int mm,int sig)
{
ll ww,v;
fo(i,1,mm-1)if(bits[i]<i)swap(c[bits[i]],c[i]);
for(int m=2;m<=mm;m<<=1){
int half=m>>1,U=ZD/m;
fo(i,0,half-1){
ww=sig==1?w[U*i]:w[ZD-U*i];
for(int j=i;j<mm;j+=m){
v=c[j+half]*ww%mo;
c[j+half]=(c[j]-v+mo)%mo;
c[j]=(c[j]+v)%mo;
}
}
}
if(sig==-1){
ll ny=ksm(mm,mo-2);
fo(i,0,mm-1)c[i]=c[i]*ny%mo;
}
}
void divi(int l,int r)
{
if(l+100>=r){
g[0]=1;
fo(i,1,r-l+1)g[i]=0;
fo(i,l,r){
fd(j,i-l+1,1)g[j]=(g[j]*a[i]-g[j-1]+mo)%mo;
g[0]=g[0]*a[i]%mo;
}
fo(i,l,r)f[i]=g[i-l+1];
return;
}
int mid=l+r>>1;
divi(l,mid); divi(mid+1,r);
int ss=0,mm=1;
while(mm<=r-l+2)mm<<=1,++ss;
fo(i,l,mid)aa[i-l+1]=f[i];
aa[0]=1;
fo(i,l,mid)aa[0]=aa[0]*a[i]%mo;
fo(i,mid-l+2,mm)aa[i]=0;
fo(i,mid+1,r)bb[i-mid]=f[i];
bb[0]=1;
fo(i,mid+1,r)bb[0]=bb[0]*a[i]%mo;
fo(i,r-mid+1,mm)bb[i]=0;
fo(i,0,mm-1)bits[i]=(bits[i>>1]>>1)|((i&1)<<ss-1);
dft(aa,mm,1);
dft(bb,mm,1);
fo(i,0,mm-1)aa[i]=aa[i]*bb[i]%mo;
dft(aa,mm,-1);
fo(i,l,r)f[i]=aa[i-l+1];
}
int main()
{
cin>>n>>k;
ll lj=ksm(n,(mo-2)*k);
fo(i,1,n)a[i]=read();
prepare();
divi(1,n); f[0]=1;
fo(i,1,n)f[0]=f[0]*a[i]%mo;
ll dq=0,ans=0;
jc[0]=k;
fo(i,1,n)jc[i]=jc[i-1]*(k-i)%mo;
if(n<k)dq=ksm(n,k-n-1);
fd(i,n,0){
if(i==k)dq=1;else dq=dq*n%mo;
ll dd=f[i]*dq%mo;
if(i!=0)dd=dd*jc[i-1]%mo;
ans=(ans+dd)%mo;
}
ans=ans*lj%mo;
ll js=1;
fo(i,1,n)js=js*a[i]%mo;
ans=(js-ans+mo)%mo;
cout<<ans;
}