Address
https://www.lydsy.com/JudgeOnline/problem.php?id=5306
Solution
先回顾一下「
{1,2,...,M}
{
1
,
2
,
.
.
.
,
M
}
个中恰好
K
K
个合法」的容斥求法:
接下来考虑一个问题:在 M M 种颜色中选出 种,每种颜色染 S S 个格子,剩下的 个格子分别染剩下的 M−i M − i 种颜色之一,求方案数。
(1)选出 i i 种颜色,方案数:
(2)每种颜色分别染 S S 格:
(3)染剩下 N−iS N − i S 格:
(M−i)N−iS
(
M
−
i
)
N
−
i
S
所以,方案数:
N!(S!)i(N−iS)!CiM(M−i)N−iS
N
!
(
S
!
)
i
(
N
−
i
S
)
!
C
M
i
(
M
−
i
)
N
−
i
S
所以,答案为:
∑k=0MWk∑i=kMN!(S!)i(N−iS)!(−1)i−kCkiCiM(M−i)N−iS
∑
k
=
0
M
W
k
∑
i
=
k
M
N
!
(
S
!
)
i
(
N
−
i
S
)
!
(
−
1
)
i
−
k
C
i
k
C
M
i
(
M
−
i
)
N
−
i
S
复杂度显然是平方的。
但根据式 (mo) 子 (shu) ,容易想到把原式转化为卷积形式。
考虑把 ∑Mi=k ∑ i = k M 后的式子改成枚举 i i 满足 为合法颜色数:
∑k=0MWk∑i=0M−kN!(S!)i+k(N−(i+k)S)!(−1)iCki+kCi+kM(M−i−k)N−(i+k)S
∑
k
=
0
M
W
k
∑
i
=
0
M
−
k
N
!
(
S
!
)
i
+
k
(
N
−
(
i
+
k
)
S
)
!
(
−
1
)
i
C
i
+
k
k
C
M
i
+
k
(
M
−
i
−
k
)
N
−
(
i
+
k
)
S
=N!∑k=0MWk∑i=0M−k(−1)i×(i+k)!i!k!×M!(i+k)!(M−i−k)!(M−i−k)N−(i+k)S(S!)i+k(N−(i+k)S)!
=
N
!
∑
k
=
0
M
W
k
∑
i
=
0
M
−
k
(
−
1
)
i
×
(
i
+
k
)
!
i
!
k
!
×
M
!
(
i
+
k
)
!
(
M
−
i
−
k
)
!
(
M
−
i
−
k
)
N
−
(
i
+
k
)
S
(
S
!
)
i
+
k
(
N
−
(
i
+
k
)
S
)
!
=N!M!∑k=0MMkk!(S!)k∑i=0M−k(−1)ii!(S!)i×(M−i−k)N−(i+k)S(M−i−k)!(N−(i+k)S)!
=
N
!
M
!
∑
k
=
0
M
M
k
k
!
(
S
!
)
k
∑
i
=
0
M
−
k
(
−
1
)
i
i
!
(
S
!
)
i
×
(
M
−
i
−
k
)
N
−
(
i
+
k
)
S
(
M
−
i
−
k
)
!
(
N
−
(
i
+
k
)
S
)
!
=N!M!∑k=0MMkk!(S!)k∑i=0M−k(−1)ii!(S!)i×(M−k−i)N+(M−k−i−M)S(M−k−i)!(N+(M−k−i−M)S)!
=
N
!
M
!
∑
k
=
0
M
M
k
k
!
(
S
!
)
k
∑
i
=
0
M
−
k
(
−
1
)
i
i
!
(
S
!
)
i
×
(
M
−
k
−
i
)
N
+
(
M
−
k
−
i
−
M
)
S
(
M
−
k
−
i
)
!
(
N
+
(
M
−
k
−
i
−
M
)
S
)
!
设:
F(i)=(−1)ii!(S!)i
F
(
i
)
=
(
−
1
)
i
i
!
(
S
!
)
i
G(i)=iN+(i−M)Si!(N+(i−M)S)!
G
(
i
)
=
i
N
+
(
i
−
M
)
S
i
!
(
N
+
(
i
−
M
)
S
)
!
H(i)=Mii!(S!)i
H
(
i
)
=
M
i
i
!
(
S
!
)
i
那么就变成了:
N!M!∑k=0MH(k)(F⨂G)(M−k)
N
!
M
!
∑
k
=
0
M
H
(
k
)
(
F
⨂
G
)
(
M
−
k
)
是一个卷积的形式,可以使用 NTT 计算出。
复杂度 O(n+mlogm) O ( n + m log m ) 。
Code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define Step(i, a, b, x) for (i = a; i <= b; i += x)
#define Pow(k, n) for (k = 1; k < n; k <<= 1)
using namespace std;
inline int read() {
int res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
return bo ? ~res + 1 : res;
}
const int N = 3e5 + 5, M = 1e7 + 5, ZZQ = 1004535809;
int n, m, s, w[N], f[N], g[N], h[N], fac[M], inv[M],
spw[N], rev[N], ff = 1, gg, tot, gp[N], res[N], ans;
int qpow(int a, int b) {
int res = 1;
while (b) {
if (b & 1) res = 1ll * res * a % ZZQ;
a = 1ll * a * a % ZZQ;
b >>= 1;
}
return res;
}
void FFT(int n, int *a, int op) {
int i, j, k, sp = n;
gp[n] = qpow(op == 1 ? 3 : 334845270, (ZZQ - 1) / n);
For (i, 0, n - 1) if (i < rev[i]) swap(a[i], a[rev[i]]);
For (i, 1, tot) sp >>= 1,
gp[sp] = 1ll * gp[sp << 1] * gp[sp << 1] % ZZQ;
Pow(k, n) {
int x = gp[k << 1];
Step (i, 0, n - 1, k << 1) {
int w = 1;
For (j, 0, k - 1) {
int u = a[i + j], v = 1ll * w * a[i + j + k] % ZZQ;
a[i + j] = (u + v) % ZZQ;
a[i + j + k] = (u - v + ZZQ) % ZZQ;
w = 1ll * w * x % ZZQ;
}
}
}
}
int main() {
int i; fac[0] = inv[0] = inv[1] = spw[0] = 1;
n = read(); m = read(); s = read();
For (i, 0, m) w[i] = read();
For (i, 1, max(n, m)) fac[i] = 1ll * fac[i - 1] * i % ZZQ;
For (i, 2, max(n, m))
inv[i] = 1ll * (ZZQ - ZZQ / i) * inv[ZZQ % i] % ZZQ;
For (i, 2, max(n, m)) inv[i] = 1ll * inv[i] * inv[i - 1] % ZZQ;
For (i, 1, m) spw[i] = 1ll * spw[i - 1] * inv[s] % ZZQ;
For (i, 0, m) {
f[i] = i & 1 ? ZZQ - 1 : 1;
f[i] = 1ll * f[i] * inv[i] % ZZQ * spw[i] % ZZQ;
}
For (i, 0, m) {
if (n + (i - m) * s < 0) continue;
g[i] = qpow(i, n + (i - m) * s);
g[i] = 1ll * g[i] * inv[i] % ZZQ * inv[n + (i - m) * s] % ZZQ;
}
For (i, 0, m) h[i] = 1ll * w[i] * inv[i] % ZZQ * spw[i] % ZZQ;
while (ff <= (m << 1)) ff <<= 1, tot++;
gg = qpow(ff, ZZQ - 2);
For (i, 0, ff - 1)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << tot - 1);
FFT(ff, f, 1); FFT(ff, g, 1);
For (i, 0, ff - 1) res[i] = 1ll * f[i] * g[i] % ZZQ;
FFT(ff, res, -1);
For (i, 0, ff - 1) res[i] = 1ll * res[i] * gg % ZZQ;
For (i, 0, m) ans = (ans + 1ll * h[i] * res[m - i] % ZZQ) % ZZQ;
cout << 1ll * fac[m] * fac[n] % ZZQ * ans % ZZQ << endl;
return 0;
}