题意
问所有
n
n
n个点的带标号无根森林的树个数的
k
k
k次方和,对
998244353
998244353
998244353取模。
n
≤
20000
,
k
≤
10
n\le20000,k\le10
n≤20000,k≤10
分析
根据prufer序列可以得到
n
n
n个点的生成树数量为
n
n
−
2
n^{n-2}
nn−2,设其指数型生成函数
A
(
x
)
=
∑
i
>
0
i
i
−
2
i
!
A(x)=\sum\limits_{i>0}\frac{i^{i-2}}{i!}
A(x)=i>0∑i!ii−2
设答案为
B
k
(
x
)
B_k(x)
Bk(x),枚举连通块数量不难得到
B
k
(
x
)
=
∑
i
≥
0
A
i
(
x
)
i
!
∗
i
k
B_k(x)=\sum_{i\ge0}\frac{A^i(x)}{i!}*i^k
Bk(x)=i≥0∑i!Ai(x)∗ik
以为连通块之间是有标号的所以要除以一个
i
!
i!
i!。
特别的当
k
=
0
k=0
k=0时有
B
0
(
x
)
=
∑
i
≥
0
A
i
(
x
)
i
!
=
e
A
(
x
)
B_0(x)=\sum_{i\ge0}\frac{A^i(x)}{i!}=e^{A(x)}
B0(x)=i≥0∑i!Ai(x)=eA(x)
两边同时取对数后求导可以得到
B
0
′
(
x
)
=
B
0
(
x
)
A
′
(
x
)
B_0'(x)=B_0(x)A'(x)
B0′(x)=B0(x)A′(x)
将式子两边写成卷积形式可以发现
∑
i
+
j
=
n
−
1
B
0
(
x
)
[
x
i
]
∗
A
′
(
x
)
[
x
j
]
=
B
0
′
(
x
)
[
x
n
−
1
]
=
n
∗
B
(
x
)
[
x
n
]
\sum_{i+j=n-1}B_0(x)[x^i]*A'(x)[x^j]=B'_{0}(x)[x^{n-1}]=n*B(x)[x^n]
i+j=n−1∑B0(x)[xi]∗A′(x)[xj]=B0′(x)[xn−1]=n∗B(x)[xn]
于是我们可以用分治FFT来计算
B
0
(
x
)
B_0(x)
B0(x),当然用多项式exp来求也可以。
又注意到有
B
k
−
1
′
(
x
)
=
A
′
(
x
)
∑
i
≥
0
A
i
−
1
(
x
)
i
!
∗
i
k
B_{k-1}'(x)=A'(x)\sum_{i\ge0}\frac{A^{i-1}(x)}{i!}*i^k
Bk−1′(x)=A′(x)i≥0∑i!Ai−1(x)∗ik
所以有
B
k
−
1
′
(
x
)
A
(
x
)
=
A
′
(
x
)
B
k
(
x
)
B'_{k-1}(x)A(x)=A'(x)B_k(x)
Bk−1′(x)A(x)=A′(x)Bk(x)
由于
B
0
(
x
)
B_0(x)
B0(x)已被求出,我们只要每次对
B
k
−
1
(
x
)
B_{k-1}(x)
Bk−1(x)求导然后乘上
A
(
x
)
A
′
(
x
)
\frac{A(x)}{A'(x)}
A′(x)A(x)即可求出
B
k
(
x
)
B_k(x)
Bk(x)。
时间复杂度
O
(
n
l
o
g
n
(
k
+
l
o
g
n
)
)
O(nlogn(k+logn))
O(nlogn(k+logn))。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
typedef long long LL;
const int N=20005;
const int MOD=998244353;
int n,m,rev[N*8],A[N*8],inv[N*8],B[N*8],L,tmp[N*8],jc[N],ny[N],tp[N*8],C[N*8];
int ksm(int x,int y)
{
int ans=1;
while (y)
{
if (y&1) ans=(LL)ans*x%MOD;
x=(LL)x*x%MOD;y>>=1;
}
return ans;
}
void NTT(int *a,int f)
{
for (int i=0;i<L;i++) if (i<rev[i]) std::swap(a[i],a[rev[i]]);
for (int i=1;i<L;i<<=1)
{
int wn=ksm(3,f==1?(MOD-1)/i/2:MOD-1-(MOD-1)/i/2);
for (int j=0;j<L;j+=(i<<1))
{
int w=1;
for (int k=0;k<i;k++)
{
int u=a[j+k],v=(LL)w*a[j+k+i]%MOD;
a[j+k]=(u+v)%MOD;a[j+k+i]=(u+MOD-v)%MOD;
w=(LL)w*wn%MOD;
}
}
}
if (f==-1)
{
int k=ksm(L,MOD-2);
for (int i=0;i<L;i++) a[i]=(LL)a[i]*k%MOD;
}
}
void get_inv(int *a,int n)
{
if (n==1) {inv[0]=ksm(a[0],MOD-2);return;}
get_inv(a,n>>1);
int lg=0;
for (L=1;L<=n*2;L<<=1,lg++);
for (int i=0;i<L;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
for (int i=0;i<n;i++) tmp[i]=a[i];
for (int i=n;i<L;i++) tmp[i]=0;
NTT(inv,1);NTT(tmp,1);
for (int i=0;i<L;i++) inv[i]=(inv[i]*2%MOD+MOD-(LL)tmp[i]*inv[i]%MOD*inv[i]%MOD)%MOD;
NTT(inv,-1);
for (int i=n;i<L;i++) inv[i]=0;
}
void solve(int l,int r)
{
if (l==r)
{
if (l) B[l]=(LL)B[l]*ksm(l,MOD-2)%MOD;
return;
}
int mid=(l+r)/2;
solve(l,mid);
int lg=0,len=(mid-l+1)*2;
for (L=1;L<=len*2;L<<=1,lg++);
for (int i=0;i<L;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
for (int i=0;i<len/2;i++) tmp[i]=B[l+i];
for (int i=len/2;i<L;i++) tmp[i]=0;
for (int i=0;i<len;i++) tp[i]=C[i];
for (int i=len;i<L;i++) tp[i]=0;
NTT(tmp,1);NTT(tp,1);
for (int i=0;i<L;i++) tmp[i]=(LL)tmp[i]*tp[i]%MOD;
NTT(tmp,-1);
for (int i=0;i<L;i++) if (i+l+1>mid&&i+l+1<=r) (B[i+l+1]+=tmp[i])%=MOD;
solve(mid+1,r);
}
int main()
{
scanf("%d%d",&n,&m);
jc[0]=jc[1]=ny[0]=ny[1]=1;
for (int i=2;i<=n;i++) jc[i]=(LL)jc[i-1]*i%MOD,ny[i]=(LL)(MOD-MOD/i)*ny[MOD%i]%MOD;
for (int i=2;i<=n;i++) ny[i]=(LL)ny[i-1]*ny[i]%MOD;
A[1]=1;
for (int i=2;i<=n;i++) A[i]=(LL)ksm(i,i-2)*ny[i]%MOD;
for (int i=0;i<=n;i++) C[i]=(LL)(i+1)*A[i+1]%MOD;
for (L=1;L<=n;L<<=1);
get_inv(C,L);
int lg=0;
for (L=1;L<=n*2;L<<=1,lg++);
for (int i=0;i<L;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
NTT(A,1);NTT(inv,1);
for (int i=0;i<L;i++) inv[i]=(LL)inv[i]*A[i]%MOD;
NTT(A,-1);NTT(inv,-1);
for (int i=n+1;i<L;i++) inv[i]=0;
B[0]=1;
solve(0,n);
lg=0;
for (L=1;L<=n*2;L<<=1,lg++);
for (int i=0;i<L;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
NTT(inv,1);
for (int k=1;k<=m;k++)
{
for (int i=0;i<=n;i++) B[i]=(LL)B[i+1]*(i+1)%MOD;
NTT(B,1);
for (int i=0;i<L;i++) B[i]=(LL)B[i]*inv[i]%MOD;
NTT(B,-1);
for (int i=n+1;i<L;i++) B[i]=0;
}
printf("%d\n",(LL)B[n]*jc[n]%MOD);
return 0;
}