题面
已知一个 n n n 个点的无权无向图中 1 号点到 i ( 1 ≤ i ≤ n ) i (1\le i\le n) i(1≤i≤n) 号点的最短路 d i d_i di,试求恰好 m m m 条边的方案数。
数据范围: n ≤ 1 0 5 , m ≤ 2 ⋅ 1 0 5 n\le 10^5, m\le 2\cdot 10^5 n≤105,m≤2⋅105,令 t i = ∑ j [ d j = i ] t_i=\sum_j [d_j=i] ti=∑j[dj=i] 则有 ∑ i t i t i + 1 ≤ 2 ⋅ 1 0 5 \sum_i t_it_{i+1}\le 2\cdot 10^5 ∑ititi+1≤2⋅105。
题解
将图变成分层图,边分为两部分,一部分是每层间的边,一部分是相邻两层的边,所以不难知道答案是
[ x m ] ∏ i = 0 D − 1 ( ∑ n ≥ 1 x n ( a i n ) ) a i + 1 ∏ i = 0 n ( ∑ n ≥ 0 x n ( a i ( a i − 1 ) / 2 n ) ) [x^m] \prod_{i=0}^{D-1} \left( \sum_{n\ge 1} x^n \binom{a_i}{n} \right)^{a_{i+1}} \prod_{i=0}^n \left( \sum_{n\ge 0} x^n \binom{a_i(a_i-1)/2}{n} \right) [xm]i=0∏D−1(n≥1∑xn(nai))ai+1i=0∏n(n≥0∑xn(nai(ai−1)/2))
这个生成函数是
( ∏ i = 0 D − 1 ( ( x + 1 ) a i − 1 ) a i + 1 ) ( x + 1 ) ∑ i = 0 D a i ( a i − 1 ) / 2 \left( \prod_{i=0}^{D-1} ((x+1)^{a_i} -1) ^{a_{i+1}} \right) (x+1)^{\sum\limits_{i=0}^D a_i(a_i-1)/2 } (i=0∏D−1((x+1)ai−1)ai+1)(x+1)i=0∑Dai(ai−1)/2
接下来应当注意到题目中奇怪的数据范围 ∑ i a i a i + 1 ≤ 2 ⋅ 1 0 5 \sum\limits_i a_ia_{i+1} \le 2\cdot 10^5 i∑aiai+1≤2⋅105,令其为 S S S,则可以发现左侧的乘积式的长度是 S S S 级别的,所以我猜测这里应该已经能够做了,但我不是太清楚。
言归正传,接下来使用套路:
∏ i = 0 D − 1 ( ( x + 1 ) a i − 1 ) a i + 1 = exp ∑ i = 0 D − 1 a i + 1 ln ( ( x + 1 ) a i − 1 ) \prod_{i=0}^{D-1} ((x+1)^{a_i} -1) ^{a_{i+1}} = \exp \sum_{i=0}^{D-1} a_{i+1} \ln ((x+1)^{a_i} -1) i=0∏D−1((x+1)ai−1)ai+1=expi=0∑D−1ai+1ln((x+1)ai−1)
好吧这是个什么玩意,ln 里边的形式幂级数一次项根本就是 0。
所以重新做一遍,我们做换元 z = x + 1 z=x+1 z=x+1,再做一次:
G ( z ) = ∏ i = 0 D − 1 ( z a i − 1 ) a i + 1 = ( − 1 ) ∑ i = 1 D a i exp ( ∑ i = 0 D − 1 a i + 1 ln ( 1 − z a i ) ) = ( − 1 ) ∑ i = 1 D a i exp ( − ∑ i = 0 D − 1 a i + 1 ∑ j ≥ 1 z j a i j ) G(z)=\prod_{i=0}^{D-1} (z^{a_i} - 1) ^{a_{i+1}} = (-1)^{\sum\limits_{i=1}^D a_i} \exp \left( \sum_{i=0}^{D-1} a_{i+1} \ln (1-z^{a_i}) \right) = (-1)^{\sum\limits_{i=1}^D a_i} \exp \left( - \sum_{i=0}^{D-1} a_{i+1} \sum_{j\ge 1} \frac{z^{ja_i}}{j} \right) G(z)=i=0∏D−1(zai−1)ai+1=(−1)i=1∑Daiexp(i=0∑D−1ai+1ln(1−zai))=(−1)i=1∑Daiexp(−i=0∑D−1ai+1j≥1∑jzjai)
如果我们要求出这个形式幂级数的前 L L L 项,由于 exp 里的形式幂级数的有效项为 O ( ∑ i L i ) = O ( L log L ) O(\sum_i \frac{L}{i})=O(L\log L) O(∑iiL)=O(LlogL),因此求出这个式子的复杂度是 O ( L log L ) O(L\log L) O(LlogL)。
现在我们需要还原换元之前的式子,很明显我们需要知道这个形式幂级数(多项式)所有的信息,所以求出 [ z 0 ∼ S ] G ( z ) [z^{0\sim S}]G(z) [z0∼S]G(z)。
F ( x ) = G ( x + 1 ) = ∑ i ≥ 0 g i ( x + 1 ) i = ∑ i ≥ 0 g i ∑ j ≥ 0 ( i j ) x j = ∑ j ≥ 0 x j ∑ i ≥ 0 g i ( i j ) = ∑ j ≥ 0 x j ∑ i ≥ 0 g i + j ( i + j ) ! i ! j ! = ∑ j ≥ 0 x j j ! ∑ i ≥ 0 h S − i − j i ! \begin{aligned} F(x) = G(x+1) = \sum_{i\ge 0} g_i (x+1)^i = \sum_{i\ge 0} g_i \sum_{j\ge 0} \binom{i}{j} x^j = \sum_{j\ge 0} x^j \sum_{i\ge 0} g_i \binom{i}{j} \\ = \sum_{j\ge 0} x^j \sum_{i\ge 0} g_{i+j} \frac{(i+j)!}{i!j!} = \sum_{j\ge 0} \frac{x^j}{j!} \sum_{i\ge 0} \frac{h_{S-i-j}}{i!} \end{aligned} F(x)=G(x+1)=i≥0∑gi(x+1)i=i≥0∑gij≥0∑(ji)xj=j≥0∑xji≥0∑gi(ji)=j≥0∑xji≥0∑gi+ji!j!(i+j)!=j≥0∑j!xji≥0∑i!hS−i−j
其中 h i = g S − i ( S − i ) ! h_i=g_{S-i} (S-i)! hi=gS−i(S−i)!,所以这里也 O ( S log S ) O(S\log S) O(SlogS) 解决。
现在要说刚才被忽略的右半部分,理性分析一下发现项数不多所以我们维护下降幂就好了。
时间复杂度 O ( n + S log S + m log m ) O(n + S\log S + m\log m) O(n+SlogS+mlogm)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define ri register int
const int MAXN = 1000005, MOD = 998244353, Gen = 3;
ll q_pow(ll a, ll b, ll p = MOD) {
ll ret = 1;
for (; b; a = a * a % p, b >>= 1) if (b & 1) ret = ret * a % p;
return ret;
}
ll inv(ll x, ll p = MOD) { return q_pow(x, p - 2, p); }
inline int add(int x, int y) { return x + y > MOD ? x + y - MOD : x + y; }
inline int dec(int x, int y) { return x - y < 0 ? x - y + MOD : x - y; }
inline void Add(int& x, int y) { x += y; if (x >= MOD) x -= MOD; }
inline void Dec(int& x, int y) { x -= y; if (x < 0) x += MOD; }
int rev[21][MAXN], fac[MAXN], ifac[MAXN], I[MAXN];
int rt[MAXN], irt[MAXN];
void Dft(int* A, int LIM, int L) {
for (ri i = 0; i < LIM; ++i) if (i < rev[L][i]) swap(A[i], A[rev[L][i]]);
for (ri l = 2; l <= LIM; l <<= 1) {
for (ri i = 0; i < LIM; i += l) {
for (ri j = 0; j < (l >> 1); ++j) {
ll x = A[j | i], y = (ll)A[j | i | (l >> 1)] * rt[l | j] % MOD;
A[j | i] = add(x, y), A[j | i | (l >> 1)] = dec(x, y);
}
}
}
}
void iDft(int* A, int LIM, int L) {
for (ri i = 0; i < LIM; ++i) if (i < rev[L][i]) swap(A[i], A[rev[L][i]]);
for (ri l = 2; l <= LIM; l <<= 1) {
for (ri i = 0; i < LIM; i += l) {
for (ri j = 0; j < (l >> 1); ++j) {
ll x = A[j | i], y = (ll)A[j | i | (l >> 1)] * irt[l | j] % MOD;
A[j | i] = add(x, y), A[j | i | (l >> 1)] = dec(x, y);
}
}
}
int invLIM = inv(LIM);
for (ri i = 0; i < LIM; ++i) A[i] = 1ll * A[i] * invLIM % MOD;
}
void Multiply(int* A, int* B, int LIM, int L) {
Dft(A, LIM, L), Dft(B, LIM, L);
for (ri i = 0; i < LIM; ++i) A[i] = 1ll * A[i] * B[i] % MOD;
iDft(A, LIM, L);
}
int Tinv[MAXN];
void Inverse(const int* F, int* G, int LIM, int L) {
G[0] = inv(F[0]);
for (ri lim = 2, l = 1; lim <= LIM; lim <<= 1, ++l) {
for (ri i = 0; i < lim; ++i) Tinv[i] = F[i];
for (ri i = lim; i < (lim << 1); ++i) Tinv[i] = 0;
Dft(Tinv, lim << 1, l + 1), Dft(G, lim << 1, l + 1);
for (ri i = 0; i < (lim << 1); ++i) G[i] = dec(add(G[i], G[i]), 1ll * Tinv[i] * G[i] % MOD * G[i] % MOD);
iDft(G, lim << 1, l + 1);
for (ri i = lim; i < (lim << 1); ++i) G[i] = 0;
}
}
void Derivative(int* F, int LIM) {
for (ri i = 0; i < LIM - 1; ++i) F[i] = 1ll * F[i + 1] * (i + 1) % MOD;
F[LIM - 1] = 0;
}
void Inter(int* F, int LIM) {
for (ri i = LIM - 1; i >= 1; --i) F[i] = 1ll * F[i - 1] * I[i] % MOD;
F[0] = 0;
}
int Tln[MAXN];
void Ln(int* F, int LIM, int L) {
Inverse(F, Tln, LIM, L);
Derivative(F, LIM);
Multiply(F, Tln, LIM << 1, L + 1);
Inter(F, LIM);
for (ri i = 0; i < (LIM << 1); ++i) Tln[i] = 0;
}
int Texp[MAXN];
void Exp(const int* F, int* G, int LIM, int L) {
G[0] = 1;
for (ri lim = 2, l = 1; lim <= LIM; lim <<= 1, ++l) {
for (ri i = 0; i < lim; ++i) Texp[i] = G[i];
for (ri i = lim; i < (lim << 1); ++i) Texp[i] = 0;
Ln(G, lim, l);
for (ri i = 0; i < lim; ++i) G[i] = dec(F[i], G[i]);
G[0] = add(G[0], 1);
Multiply(G, Texp, lim << 1, l + 1);
for (ri i = lim; i < (lim << 1); ++i) G[i] = 0;
}
}
int N, M, LIM = 1, L, a[MAXN], b[MAXN], D, S;
int F[MAXN], G[MAXN], H[MAXN];
int main() {
scanf("%d%d", &N, &M); int t;
for (int i = 1; i <= N; ++i) scanf("%d", &t), ++a[t], D = max(t, D);
for (int i = 0; i < D; ++i) S += a[i] * a[i + 1];
while (LIM <= max(M, S) + 2) {
LIM <<= 1, ++L;
for (ri i = 0; i < LIM; ++i) rev[L][i] = (rev[L][i >> 1] >> 1) | ((i & 1) << (L - 1));
irt[LIM] = rt[LIM] = 1; int Wn = q_pow(Gen, (MOD - 1) / LIM), iWn = inv(Wn);
for (ri i = 1; i < (LIM >> 1); ++i) rt[i | LIM] = 1ll * rt[(i - 1) | LIM] * Wn % MOD, irt[i | LIM] = 1ll * irt[(i - 1) | LIM] * iWn % MOD;
}
LIM <<= 1, ++L;
for (ri i = 0; i < LIM; ++i) rev[L][i] = (rev[L][i >> 1] >> 1) | ((i & 1) << (L - 1));
irt[LIM] = rt[LIM] = 1; int Wn = q_pow(Gen, (MOD - 1) / LIM), iWn = inv(Wn);
for (ri i = 1; i < (LIM >> 1); ++i) rt[i | LIM] = 1ll * rt[i - 1 | LIM] * Wn % MOD, irt[i | LIM] = 1ll * irt[i - 1 | LIM] * iWn % MOD;
LIM >>= 1, --L;
fac[0] = ifac[0] = fac[1] = ifac[1] = I[1] = 1;
for (ri i = 2; i <= LIM; ++i) {
fac[i] = 1ll * fac[i - 1] * i % MOD, I[i] = 1ll * (MOD - MOD / i) * I[MOD % i] % MOD, ifac[i] = 1ll * ifac[i - 1] * I[i] % MOD;
} for (int i = 0; i < D; ++i) Dec(b[a[i]], a[i + 1]);
for (int i = 0; i < LIM; ++i) if (b[i]) {
for (int j = 1; j * i < LIM; ++j) Add(F[j * i], 1ll * b[i] * I[j] % MOD);
}
Exp(F, G, LIM, L);
if (N % 2 == 0) {
for (int i = 0; i < LIM; ++i) G[i] = dec(0, G[i]);
}
for (int i = 0; i < LIM; ++i) H[i] = F[i] = 0;
for (int i = 0; i <= S; ++i) H[S - i] = 1ll * G[i] * fac[i] % MOD;
for (int i = 0; i <= S; ++i) F[i] = ifac[i];
Multiply(H, F, LIM << 1, L + 1);
for (int i = 0; i <= S; ++i) F[i] = 1ll * H[S - i] * ifac[i] % MOD;
for (int i = S + 1; i < LIM; ++i) F[i] = 0;
int sum = 1, tmp = 0;
for (int i = 0; i <= D; ++i) Add(tmp, 1ll * a[i] * dec(a[i], 1) % MOD * I[2] % MOD);
for (int i = 0; i < LIM; ++i) G[i] = 0;
for (int i = 0; i <= max(M, S); ++i) G[i] = 1ll * ifac[i] * sum % MOD, sum = 1ll * sum * tmp % MOD, Dec(tmp, 1);
Multiply(F, G, LIM << 1, L + 1);
printf("%d\n", F[M]);
return 0;
}