即要求计算一个集合幂级数在子集卷积意义下的截断 exp:
$$\exp_{\le k} f = \sum_{j=0}^{k} \frac{f^j}{j!}$$
第一步自然是做 FMT(按集合大小分层的快速莫比乌斯变换,即带秩的 zeta 变换),然后在每个分量上进行形式幂级数计算。
考虑到 $g(z) = \exp_{\le k} f(z)$ 满足方程
$$g' = g f' - \frac{f^k}{k!} f'$$
(对 $g = \sum_{j=0}^{k} f^j / j!$ 逐项求导即得 $g' = f' \sum_{j=0}^{k-1} \frac{f^j}{j!} = f'\left(g - \frac{f^k}{k!}\right)$),初值为 $g(0) = \sum_{j=0}^{k} \frac{f(0)^j}{j!}$。
通过递推式,此问题可在 $\Theta(n^2 2^n)$ 时间内解决。
UPD 2020.11.22: 被通解打爆了
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <algorithm>
#define LOG(FMT...) fprintf(stderr, FMT)
using namespace std;
// 64-bit accumulator type. Note: despite the short name it is UNSIGNED.
using ll = unsigned long long;

// N: capacity for the universe size n; P: the modulus used everywhere.
constexpr int N = 21;
constexpr int P = 998244353;

// Input parameters (filled in main).
int n;
int m;
int k;
// Holds 1 / k! (mod P) once main has computed it.
int ifac = 1;

// f[s][i]: coefficient for mask s at rank (popcount) i.
int f[1 << N][N + 1];

// Scratch polynomials of degree <= N, and iinv[i] = 1 / i (mod P).
int a[N + 1];
int b[N + 1];
int c[N + 1];
int iinv[N + 1];
// Folds a value from [0, 2P) back into [0, P) with a single conditional subtraction.
int norm(int x) { return x < P ? x : x - P; }
// Extended Euclid: sets x, y such that a * x + b * y == gcd(a, b).
// Iterative form of the textbook recursion; both multiply the same sequence
// of quotient matrices, so the produced coefficients are identical.
void exGcd(int a, int b, int& x, int& y) {
  int x0 = 1, y0 = 0;  // coefficients that express the current a
  int x1 = 0, y1 = 1;  // coefficients that express the current b
  while (b != 0) {
    const int q = a / b;
    const int r = a - q * b;  // == a % b
    a = b;
    b = r;
    const int nx = x0 - q * x1;
    x0 = x1;
    x1 = nx;
    const int ny = y0 - q * y1;
    y0 = y1;
    y1 = ny;
  }
  x = x0;
  y = y0;
}
// Modular inverse of a modulo P, obtained from the extended-Euclid
// coefficient of a and shifted into [0, P).  inv(0) yields 0.
int inv(int a) {
  int x = 0;
  int y = 0;
  exGcd(a, P, x, y);
  return x < 0 ? x + P : x;
}
int mpow(int x, int k) {
int ret = 1;
while (k) {
if (k & 1)
ret = ret * (ll)x % P;
k >>= 1;
x = x * (ll)x % P;
}
return ret;
}
int main() {
scanf("%d%d%d", &n, &m, &k);
for (int i = 1; i <= n; ++i) iinv[i] = inv(i);
for (int i = 1; i <= k; ++i) ifac = ifac * (ll)iinv[i] % P;
while (m--) {
int s;
scanf("%d", &s);
++f[s][__builtin_popcount(s)];
}
for (int i = 0; i < n; ++i)
for (int s = 0; s < 1 << n; ++s)
if (!(s >> i & 1))
for (int j = 0; j <= n; ++j) f[s | 1 << i][j] = norm(f[s | 1 << i][j] + f[s][j]);
int ans = 0;
for (int s = 0; s < 1 << n; ++s) {
int lead = -1;
for (int i = 0; i <= n; ++i)
if (f[s][i]) {
lead = i;
break;
}
if (lead == -1 || lead * k > n)
continue;
int in = inv(f[s][lead]), pw = mpow(f[s][lead], k);
memset(a, 0, sizeof(a));
for (int i = lead; i <= n; ++i) a[i - lead] = f[s][i] * (ll)in % P;
for (int i = 0; i < n - lead * k; ++i) b[i] = a[i + 1] * (ll)(i + 1) % P;
b[n] = 0;
for (int i = 0; i <= n - lead * k; ++i) {
ll v = b[i];
for (int j = 1; j <= i; ++j) {
v += (P - a[j]) * (ll)c[i - j];
if ((j & 15) == 15)
v %= P;
}
c[i] = v % P;
}
for (int i = 0; i <= n - lead * k; ++i) c[i] = c[i] * (ll)k % P;
a[0] = 1;
for (int i = 1; i <= n - lead * k; ++i) {
ll v = 0;
for (int j = 0; j < i; ++j) {
v += a[j] * (ll)c[i - 1 - j];
if ((j & 15) == 15)
v %= P;
}
a[i] = v % P * (ll)iinv[i] % P;
}
for (int i = lead * k; i <= n; ++i) b[i] = a[i - lead * k] * (ll)pw % P * ifac % P;
memset(c, 0, sizeof(c));
for (int i = 0; i < n; ++i) a[i] = f[s][i + 1] * (ll)(i + 1) % P;
for (int i = lead * k; i <= n; ++i) {
ll v = 0;
for (int j = 0; j <= i - lead * k; ++j) {
v += b[i - j] * (ll)a[j];
if ((j & 15) == 15)
v %= P;
}
c[i] = v % P;
}
for (int i = 0; i < n; ++i) b[i] = f[s][i + 1] * (ll)(i + 1) % P;
b[n] = 0;
a[0] = 1;
for (int i = 1; i <= n; ++i) {
ll v = norm(P - c[i - 1]);
for (int j = 0; j < i; ++j) {
v += a[j] * (ll)b[i - 1 - j];
if ((j & 15) == 15)
v %= P;
}
a[i] = v % P * iinv[i] % P;
}
if (__builtin_parity(s ^ ((1 << n) - 1)))
ans = norm(ans - a[n] + P);
else
ans = norm(ans + a[n]);
}
printf("%d\n", ans);
return 0;
}