题目:http://ifrog.cc/acm/problem/1035
题意:
对所有
n
个点的带标号无根森林,求树个数的
n≤20000,k≤10
。
题解:
令
a(n)
表示
n
点完全图的生成树个数,结合树的 prüfer 序列可知
为了进行有标号的计数,定义
A(x)=∑i≥0a(i)i!xi
表示
n
点带标号连通图生成树的指数型生成函数,定义
考虑
A(x)
与
Bk(x)
的关系,不妨枚举连通块的数量为
i
,对应的森林数量应该是
其中
B0(x)=∑i≥0Ai(x)i!
,通过观察
ex
的泰勒展开或是利用级数求和的技巧,可以得到
B0(x)=eA(x)
,也即
B0(x)
的常数项为
1
,且
对于
k>0
的情况,可以发现
B′k−1(x)=A′(x)∑i≥0ikAi−1(x)i!
,所以有
A′(x)Bk(x)=A(x)B′k−1(x)
,也即
Bk(x)=B′k−1(x)A(x)A′(x)
。
通过上面的推导已经可以计算
B0(x),B1(x),⋯,Bk(x)
的
xi
系数了,其中
i=0,1,⋯,n
。令
f(x)[xn]
表示
f(x)
的
xn
的系数,则具体做法如下:
- 观察 A′(x)B0(x)=B′0(x) 等式两边 xn−1(n>0) 的系数,有 ∑i+j=n−1B0(x)[xi]⋅A′(x)[xj]=B′0(x)[xn−1]=n⋅B0(x)[xn] ,利用分治计算即可。
- 利用多项式求逆预处理 A(x)A′(x) ,计算出 Br−1(x) 后求导再乘以 A(x)A′(x) 即可得到 Br(x) ,其中 r=1,2,⋯,k 。
实际上
Bk(x)=eA(x)∑ki=0Stiring2(k,i)Ai(x)
,也可以用类似的方法直接求解。
时间复杂度
O(nlogn(k+logn))
,贴一下我的 NTT 的模板吧。
代码:
#include <stdio.h>
#include <cstring>
#include <algorithm>
typedef long long LL;
const int maxn = 20001, maxk = 11, maxlen = 16, maxm = 1 << maxlen, mod = 998244353, gen = 3;
int w[maxm], inv2[maxlen + 1];
inline int mod_pow(int x, int k)
{
int ret = 1;
for( ; k; k >>= 1, x = (LL)x * x % mod)
if(k & 1)
ret = (LL)ret * x % mod;
return ret;
}
inline int mod_add(int x, int y)
{
return (x += y) < mod ? x : x - mod;
}
inline int mod_sub(int x, int y)
{
return (x -= y) < 0 ? x + mod : x;
}
inline void NTT(int len, int x[], int flag)
{
static int hisLen = -1, bitLen, bitrev[maxm];
for(bitLen = 0; 1 << bitLen < len; ++bitLen);
if(hisLen != bitLen)
{
for(int i = 1; i < len; ++i)
bitrev[i] = (bitrev[i >> 1] >> 1) | ((i & 1) << (bitLen - 1));
hisLen = bitLen;
}
for(int i = 1; i < len; ++i)
if(i < bitrev[i])
std::swap(x[i], x[bitrev[i]]);
for(int i = 1, d = 1; d < len; ++i, d <<= 1)
for(int j = 0; j < len; j += d << 1)
for(int k = 0, *X = x + j; k < d; ++k)
{
int t = (LL)w[k << (maxlen - i)] * X[k + d] % mod;
X[d + k] = mod_sub(X[k], t);
X[k] = mod_add(X[k], t);
}
if(flag == -1)
{
std::reverse(x + 1, x + len);
int val = inv2[bitLen];
for(int i = 0; i < len; ++i)
x[i] = (LL)x[i] * val % mod;
}
}
void PolyInv(int n, int cur[], int nxt[maxm])
{
// EFFECTS: NXT = CUR^(-1), NXT[len/2 : len] is any number, len >= n * 2
int len;
static int tmp[maxm];
// nxt = CUR^(-1) (mod x)
nxt[0] = 1; //mod_inv(cur[0]);
for(len = 2; (len >> 1) < n; len <<= 1)
{
// before here, nxt = CUR^(-1) (mod x^(len/2)), nxt[len/2 : len] is any number
int lim = std::min(n, len);
// tmp = CUR (mod x^min(n,len)) -> tmp (mod x^(len*2))
memcpy(tmp, cur, lim * sizeof(int));
memset(tmp + lim, 0, ((len << 1) - lim) * sizeof(int));
NTT(len << 1, tmp, 1);
// nxt (mod x^(len/2)) -> nxt (mod x^(len*2))
memset(nxt + (len >> 1), 0, ((len << 1) - (len >> 1)) * sizeof(int));
NTT(len << 1, nxt, 1);
// nxt = (2 - CUR * nxt) * nxt (mod x^len), nxt[len : len * 2] is any number
for(int i = 0; i < len << 1; ++i)
if((nxt[i] = (2 - (LL)tmp[i] * nxt[i]) % mod * nxt[i] % mod) < 0)
nxt[i] += mod;
NTT(len << 1, nxt, -1);
}
}
int n, kk, inv[maxn], A[maxm | 1], f[maxk][maxm | 1], *B = f[0], *C = f[1], *D = f[2], *E = f[3];
void cdq(int L, int R) // B C = B'
{
if(L == R)
{
B[L] = L ? (LL)B[L] * inv[L] % mod : 1;
return;
}
int M = (L + R) >> 1;
cdq(L, M);
int len, plen = R - L, qlen = M - L + 1;
for(len = 1; len < plen; len <<= 1);
memcpy(D, B + L, qlen * sizeof(int));
memset(D + qlen, 0, (len - qlen) * sizeof(int));
NTT(len, D, 1);
memcpy(E, C, plen * sizeof(int));
memset(E + plen, 0, (len - plen) * sizeof(int));
NTT(len, E, 1);
for(int i = 0; i < len; ++i)
D[i] = (LL)D[i] * E[i] % mod;
NTT(len, D, -1);
for(int i = M + 1, *_D = D - L - 1; i <= R; ++i)
if((B[i] += _D[i]) >= mod)
B[i] -= mod;
cdq(M + 1, R);
}
int main()
{
w[0] = 1;
w[1] = mod_pow(gen, (mod - 1) >> maxlen); // make sure that mod = 2 ^ maxlen * k + 1
for(int i = 2; i < maxm; ++i)
w[i] = (LL)w[i - 1] * w[1] % mod;
inv2[0] = 1;
inv2[1] = (mod + 1) >> 1;
for(int i = 2; i <= maxlen; ++i)
inv2[i] = (LL)inv2[i - 1] * inv2[1] % mod;
scanf("%d%d", &n, &kk);
inv[1] = 1;
for(int i = 2; i <= n; ++i)
inv[i] = mod - (int)(mod / i * (LL)inv[mod % i] % mod);
int len;
for(len = 1; len <= n << 1; len <<= 1);
A[1] = C[0] = 1;
for(int i = 2, iact = 1; i <= n; ++i)
{
iact = (LL)iact * inv[i] % mod;
A[i] = (LL)mod_pow(i, i - 2) * iact % mod;
C[i - 1] = (LL)A[i] * i % mod;
}
cdq(0, n);
PolyInv(n, C, D);
memset(D + n, 0, (len - n) * sizeof(int));
NTT(len, A, 1);
NTT(len, D, 1);
for(int i = 0; i < len; ++i)
A[i] = (LL)A[i] * D[i] % mod;
NTT(len, A, -1);
memset(A + n + 1, 0, (len - n - 1) * sizeof(int));
NTT(len, A, 1);
for(int i = 1; i <= kk; ++i)
{
for(int j = 1; j <= n; ++j)
f[i][j - 1] = (LL)f[i - 1][j] * j % mod;
memset(f[i] + n, 0, (len - n) * sizeof(int));
NTT(len, f[i], 1);
for(int j = 0; j < len; ++j)
f[i][j] = (LL)f[i][j] * A[j] % mod;
NTT(len, f[i], -1);
}
for(int i = 2; i <= n; ++i)
f[kk][n] = (LL)f[kk][n] * i % mod;
printf("%d\n", f[kk][n]);
return 0;
}