题意
给定一个长度为 $n$ 的字符串 $S$,任取 $S$ 的一个非空子序列并将其任意重排,求能得到多少种不同的字符串,答案对 $998244353$ 取模。
$1 \le n \le 5 \times 10^3$
分析:
设 $26$ 个英文字母中每个字母 $u$ 的生成函数为(因为要计算排列,所以是 $\textbf{EGF}$,即指数型生成函数)

$$1 + \frac{1}{1!}x + \frac{1}{2!}x^2 + \cdots + \frac{1}{cnt_{u}!}x^{cnt_{u}}$$

其中 $cnt_{u}$ 为字母 $u$ 在 $S$ 中出现的次数。
那么最终所有方案的生成函数为

$$\prod_{u = 1}^{26}\left(1 + \frac{1}{1!}x + \frac{1}{2!}x^2 + \cdots + \frac{1}{cnt_{u}!}x^{cnt_{u}}\right)$$
记上述乘积为 $F(x)$,那么计算子序列的方案数就是把字符串长度为 $1 \sim n$ 的所有方案(即 $x^i$ 的系数)乘上长度的阶乘再加起来,即

$$\sum_{i = 1}^{n} i! \times [x^i]F(x)$$
因为最多只有 $26$ 个字母,所以直接每次暴力 $\text{NTT}$ 相乘即可,当然也可以用分治 $\text{FFT}$。
官方标程是 $\text{DP}$,可以发现多项式做法比标程快一些。
代码:
#include <bits/stdc++.h>
// Every `int` written below is really a 64-bit `long long`, so the product of
// two residues smaller than mod (each below 2^30) cannot overflow.
#define int long long
// A polynomial is its coefficient vector; index i holds the coefficient of x^i.
#define poly vector<int>
// Element count of a container, converted to the (64-bit) signed `int` above.
#define len(x) ((int)x.size())
using namespace std;
// N    : capacity of the static tables rev/fact/infact declared below.
// M    : not referenced anywhere in the visible code.
// g    : base whose powers serve as roots of unity in the forward NTT.
// ginv : modular inverse of g (3 * 332748118 == 998244354 == 1 mod 998244353),
//        used for the inverse NTT.
// mod  : the modulus 998244353 all arithmetic is reduced by.
const int N = 2e4 + 5, M = 35, g = 3, ginv = 332748118, mod = 998244353;
// n, m    : not referenced in the visible code (operator* declares its own local n).
// rev     : bit-reversal permutation for the current transform length.
// lim     : current transform length (a power of two), set by operator*.
// liminv  : modular inverse of lim, applied after the inverse NTT.
// cnt     : occurrence count per letter, indexed by (letter - 'a').
// fact    : fact[i] = i! mod mod;  infact[i] = inverse of i! mod mod.
// res     : accumulator for the final answer.
int n, m, rev[N], lim, liminv, cnt[35], fact[N], infact[N], res;
// The input string.
string str;
// ans: running product of the per-letter polynomials; tmp: the next factor.
poly ans, tmp;
// Modular exponentiation by repeated squaring.
// Returns a^b reduced modulo `mod` for b >= 0 (1 when b == 0).
// The caller is expected to pass `a` small enough that a * a fits in the
// 64-bit type that `int` is defined to in this file; `a` is not reduced first.
int qmi(int a, int b) {
    int power = 1;
    // Consume the exponent one bit at a time, lowest bit first.
    for (; b; b >>= 1) {
        if (b & 1) {
            power = power * a % mod;
        }
        a = a * a % mod;
    }
    return power;
}
// In-place number-theoretic transform of the first `lim` entries of `f`.
//   op ==  1 : forward transform, roots are powers of g.
//   op == -1 : inverse transform, roots are powers of ginv, followed by a
//              multiplication of every entry by liminv.
// Reads the globals lim, rev and liminv, which operator* sets up before
// calling; `f` must hold at least `lim` elements.
void NTT(poly &f, int op) {
    // Reorder the coefficients into bit-reversed order; the `<` test makes
    // sure each pair is swapped only once.
    for (int pos = 0; pos < lim; ++pos) {
        const int partner = rev[pos];
        if (pos < partner) swap(f[pos], f[partner]);
    }

    const int root = (op == 1) ? g : ginv;

    // Butterfly passes: `half` is the half-length of the blocks being merged.
    for (int half = 1; half < lim; half <<= 1) {
        const int width = half << 1;
        // Step between consecutive twiddle factors for blocks of this width.
        const int step = qmi(root, (mod - 1) / width);
        for (int base = 0; base < lim; base += width) {
            int twiddle = 1;
            for (int k = 0; k < half; ++k) {
                const int lo = f[base + k];
                const int hi = twiddle * f[base + k + half] % mod;
                f[base + k] = (lo + hi) % mod;
                // Adding mod keeps the difference non-negative before reducing.
                f[base + k + half] = (lo - hi + mod) % mod;
                twiddle = twiddle * step % mod;
            }
        }
    }

    if (op != -1) return;

    // Inverse transform only: divide every entry by the transform length.
    for (int pos = 0; pos < lim; ++pos) {
        f[pos] = f[pos] * liminv % mod;
    }
}
// Product of two polynomials with coefficients reduced modulo `mod`.
//
// Both operands are taken by value because they are zero-padded and
// transformed in place. The result has len(f) + len(g) - 1 coefficients;
// if either operand is empty the result is the empty (zero) polynomial.
//
// Side effects: overwrites the globals lim, liminv and rev, which NTT reads.
// Throws std::length_error when the padded transform length does not fit in
// the fixed-size rev table (N entries).
poly operator * (poly f, poly g) {
    // An empty operand is the zero polynomial. Without this guard two empty
    // operands give a length of -1, and f.resize(-1) converts that to a huge
    // unsigned size and throws.
    if (f.empty() || g.empty()) return poly();

    const int n = len(f) + len(g) - 1;

    // Smallest power of two that can hold every coefficient of the product.
    lim = 1;
    while (lim < n) lim <<= 1;

    // rev has only N slots; writing past it would corrupt neighbouring globals.
    if (lim > N) {
        throw std::length_error("polynomial product exceeds NTT table capacity");
    }
    liminv = qmi(lim, mod - 2);

    // Bit-reversal table: rev[i] is rev[i / 2] shifted right by one, with the
    // lowest bit of i moved to the top position.
    for (int i = 0; i < lim; i ++) rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? lim >> 1 : 0);

    f.resize(lim), g.resize(lim);
    NTT(f, 1), NTT(g, 1);
    // Pointwise product in the transformed domain.
    for (int i = 0; i < lim; i ++) f[i] = f[i] * g[i] % mod;
    NTT(f, -1);

    // Drop the zero padding beyond the true length of the product.
    f.resize(n);
    return f;
}
// Reads a string from standard input, multiplies together one polynomial per
// letter (coefficient of x^j is the inverse of j!, for j up to that letter's
// count), then prints the sum over i >= 1 of i! times the coefficient of x^i
// of the product, reduced modulo `mod`.
signed main() {
    // Factorials, then inverse factorials from a single modular inverse:
    // infact[i - 1] = infact[i] * i. N - 1 is below mod, so fact[N - 1] is
    // invertible.
    fact[0] = 1;
    for (int i = 1; i < N; i ++) {
        fact[i] = fact[i - 1] * i % mod;
    }
    infact[N - 1] = qmi(fact[N - 1], mod - 2);
    for (int i = N - 1; i >= 1; i --) {
        infact[i - 1] = infact[i] * i % mod;
    }

    cin >> str;
    for (char ch : str) {
        // cnt only has slots for 'a'..'z'; any other character would index
        // outside the array (uppercase letters even give a negative index),
        // so such characters are ignored.
        if (ch < 'a' || ch > 'z') continue;
        cnt[ch - 'a'] ++;
    }

    // Start the running product with the polynomial for 'a'.
    ans.assign(infact, infact + cnt[0] + 1);
    for (int i = 1; i < 26; i ++) {
        // A letter that never occurs contributes the constant polynomial 1;
        // multiplying by it changes nothing, so skip the transform entirely.
        if (cnt[i] == 0) continue;
        tmp.assign(infact, infact + cnt[i] + 1);
        ans = ans * tmp;
    }

    // Index 0 is skipped: it corresponds to the empty selection.
    for (int i = 1; i < len(ans); i ++) {
        res = (res + fact[i] * ans[i] % mod) % mod;
    }
    cout << res << '\n';
}