题目链接:https://www.luogu.com.cn/problem/P4491
这显然是一道计数题
我们先设 $f[i]$ 代表钦定恰好出现 $S$ 次的颜色有 $i$ 种的方案数(只钦定这 $i$ 种颜色恰好出现 $S$ 次,其余颜色任意,因此会有重复计数)
则:
$$f[k]=\binom{m}{k}\binom{n}{sk}\frac{(sk)!}{(s!)^k}(m-k)^{n-sk}$$
其中 $\frac{(sk)!}{(s!)^k}=\binom{sk}{s,\,s,\,\dots,\,s}$ 是多重组合数:把选出的 $sk$ 个位置分给这 $k$ 种颜色,每种恰好 $s$ 个;剩下的 $n-sk$ 个位置任意涂其余 $m-k$ 种颜色。当 $sk>n$ 时 $f[k]=0$。
那么由二项式反演可得:
$$g[k]=\sum_{i=k}^{m}(-1)^{i-k}\binom{i}{k}f[i]$$
其中 $g[k]$ 代表恰好出现 $S$ 次的颜色恰好有 $k$ 种的方案数
那么,最后的答案为:$$ans=\sum_{i=0}^{m}w[i]\cdot g[i]$$
如果直接求 $g[i]$,复杂度是 $O(m^2)$ 的
我们将 $g[i]$ 中的组合数拆开:
$$g[i]=(-1)^i\,\frac{m!\,n!}{i!}\sum_{k=i}^{m}(-1)^k\,\frac{(m-k)^{n-sk}}{(k-i)!\,(m-k)!\,(n-sk)!\,(s!)^k}$$
我们把只含 $k-i$ 的项设为 $b_{k-i}$(即 $b_j=\frac{1}{j!}$),$a_k$ 代表只含 $k$ 的其余项(即 $a_k=\frac{(-1)^k(m-k)^{n-sk}}{(m-k)!\,(n-sk)!\,(s!)^k}$,当 $sk>n$ 时取 $0$)
那么 $\sum$ 后面的部分可以写成:$$\sum_{k=i}^{m}a_k\,b_{k-i}$$
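这是一个差卷积的形式:令 $b'_j=b_{m-j}$($0\le j\le m$,即把 $b$ 翻转),用 NTT 计算 $a$ 与 $b'$ 的卷积,则卷积结果的第 $i+m$ 项恰为 $\sum_{k=i}^{m}a_k\,b_{k-i}$。因此结果数组下标 $m$ 到 $2m$ 的这 $m+1$ 项,依次就是 $g[0],\dots,g[m]$ 中的求和部分。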
Code
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
using namespace std;
// Every `int` below this line is really `long long`, so the product of two
// residues below Mod (< 2^30) fits without overflow. This is also why main is
// declared `signed main()`.
#define int long long
// MAXM: capacity of the colour-indexed arrays; MAXN: capacity of the factorial
// table p[] (indexed up to n). Mod = 1004535809 = 479 * 2^21 + 1 is the NTT
// modulus, and G = 3 is used as its primitive root.
const int MAXM = 2e5, MAXN = 1e7, Mod = 1004535809, G = 3;
// a: weights w[0..m]; p: factorials (p[i] = i! mod Mod);
// inv: inverse factorials (inv[i] = (i!)^{-1} mod Mod); invG: G^{-1}, set in NTT::MUL.
int a[MAXM + 10], p[MAXN + 10], inv[MAXM + 10], invG;
// b: the a_k sequence of the derivation (overwritten by the convolution result);
// c: 1/i!, reversed before the convolution.
int b[MAXM + 10], c[MAXM + 10];
// Fast non-negative integer reader, defined at the bottom of the file.
inline int read();
/*
 * Binary modular exponentiation: returns x^p mod Mod for p >= 0.
 * x is reduced into [0, Mod) first. x^0 is 1 for every x, including x = 0.
 * The counting formula needs 0^0 = 1 for the term (m-k)^{n-sk} when k = m and
 * n = s*m (every colour used exactly s times).
 */
int fastpow(int x, int p){
    x %= Mod;
    if (x < 0) x += Mod;
    int ans = 1;
    for (; p > 0; p >>= 1){
        if (p & 1) ans = ans * x % Mod;
        x = x * x % Mod;
    }
    return ans;
}
/*
 * Fills the factorial table p[0..max(m, n)] with p[i] = i! mod Mod, and the
 * inverse-factorial table inv[0..m] with inv[i] = (i!)^{-1} mod Mod.
 * Only one modular exponentiation is done (for inv[m]). The remaining inverses
 * follow from inv[i-1] = inv[i] * i, which is O(m) instead of O(m log Mod).
 */
void init(int m, int n){
    p[0] = 1;
    for (int i = 1; i <= m; ++i)
        p[i] = p[i - 1] * i % Mod;
    inv[m] = fastpow(p[m], Mod - 2);
    for (int i = m; i >= 1; --i)
        inv[i - 1] = inv[i] * i % Mod;  // (i-1)!^{-1} = i!^{-1} * i
    // The rest of the factorial table; these entries need no inverses.
    for (int i = m + 1; i <= n; ++i)
        p[i] = p[i - 1] * i % Mod;
}
namespace NTT{
    /*
     * In-place iterative number-theoretic transform of f[0..n) modulo Mod,
     * using G as the root.
     *   tree : bit-reversal permutation of length n (filled by MUL).
     *   n    : transform length; must be a power of two dividing Mod - 1.
     *   op   : 0 -> forward transform. Nonzero -> inverse transform, after
     *          which every coefficient is multiplied by op. MUL passes n^{-1}
     *          mod Mod so the result is normalised.
     */
    void NTT(int *f, int *tree, int n, int op){
        for (int i = 1; i < n; ++i)
            if (i > tree[i]) swap(f[i], f[tree[i]]);
        // Compute G^{-1} locally instead of relying on the global invG
        // having been set by an earlier MUL call.
        const int root = op ? fastpow(G, Mod - 2) : G;
        for (int p = 2; p <= n; p <<= 1){
            const int len = p >> 1;
            const int rg = fastpow(root, (Mod - 1) / p);
            for (int k = 0; k < n; k += p){
                int buf = 1;
                for (int l = k; l < k + len; ++l){
                    const int tmp = f[l + len] * buf % Mod;
                    f[l + len] = f[l] - tmp;
                    if (f[l + len] < 0) f[l + len] += Mod;
                    f[l] = (f[l] + tmp) % Mod;
                    buf = buf * rg % Mod;
                }
            }
        }
        if (op){
            // Only indices [0, n) belong to the transform.
            for (int i = 0; i < n; ++i)
                f[i] = f[i] * op % Mod;
        }
    }

    /*
     * Polynomial product modulo Mod: a[0..n] * b[0..m] -> a[0..n+m].
     * a must have room for n + m + 1 entries; b is only read.
     * The static work buffers hold 2 * MAXM + 10 entries, and the transform
     * length (the smallest power of two > n + m) must fit in them.
     * Also stores G^{-1} in the global invG, as before.
     */
    void MUL(int *a, int *b, int n, int m){
        static int f[MAXM * 2 + 10], g[MAXM * 2 + 10], tree[MAXM * 2 + 10];
        for (int i = 0; i <= n; ++i) f[i] = a[i];
        for (int i = 0; i <= m; ++i) g[i] = b[i];
        const int deg = n + m;
        int len = 1;
        while (len <= deg) len <<= 1;
        const int invn = fastpow(len, Mod - 2);
        invG = fastpow(G, Mod - 2);
        for (int i = 1; i < len; ++i)
            tree[i] = (tree[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
        NTT(f, tree, len, 0), NTT(g, tree, len, 0);
        for (int i = 0; i < len; ++i) f[i] = f[i] * g[i] % Mod;
        NTT(f, tree, len, invn);
        for (int i = 0; i <= deg; ++i) a[i] = f[i];
        // Reset the shared static buffers so a later call starts from zeros.
        for (int i = 0; i < len; ++i) f[i] = g[i] = 0;
    }
}
/*
void DEBUG(){
static int a[MAXN + 10], b[MAXN + 10];
int n = read(), m = read();
for (register int i = 0; i <= n; ++i) a[i] = read();
for (register int i = 0; i <= m; ++i) b[i] = read();
NTT::MUL(a, b, n, m);
for (register int i = 0; i <= n + m; ++i)
cerr << a[i] << " ";
cerr << endl;
}
*/
signed main(){
freopen ("std.in", "r", stdin);
freopen ("std.out", "w", stdout);
//DEBUG();
int n, m, s, sum;
n = read(), m = read(), s = read();
for (register int i = 0; i <= m; ++i) a[i] = read();
init(max(m, s), n); sum = 1;
for (register int i = 0; i <= m; ++i){
int x = i % 2? Mod - 1 : 1;
if (n >= s * i) b[i] = x * inv[m - i] % Mod * fastpow(p[n - s * i], Mod - 2) % Mod * sum % Mod * fastpow(m - i, n - s * i) % Mod;
c[i] = inv[i];
sum = sum * inv[s] % Mod;
}
reverse(c, c + m + 1);
NTT::MUL(b, c, m, m);
int ans = 0;
for (register int i = 0; i <= m; ++i){
int x = i % 2? Mod - 1 : 1;
ans = (ans + x * p[n] % Mod * p[m] % Mod * a[i] % Mod * inv[i] % Mod * b[i + m]) % Mod;
}
printf("%lld\n", ans);
/*
static int tree[MAXN + 10];
for (register int i = 1; i < n; ++i)
tree[i] = (tree[i >> 1] >> 1) | ((i & 1)? n >> 1 : 0);
*/
return 0;
}
/*
 * Reads the next non-negative decimal integer from stdin. Characters before
 * the first digit are skipped; signs are not handled, so a '-' is skipped like
 * any other non-digit. At end of input it stops and returns what has been
 * read so far (0 if no digit was found) instead of looping forever.
 */
inline int read(){
    int x = 0;
    int c = getchar();  // int, not char: EOF must stay distinguishable from data
    while (c != EOF && !isdigit(c)) c = getchar();
    while (c != EOF && isdigit(c)){
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}