在应用中,AC自动机大多数是与DP结合起来用的,当然也有其他类型的应用。
最经典的应用之一:
给出一些串,这些串是“病毒串”,问有多少种长度为n且不包含病毒串(或者至少出现一次)的字符串。
这类问题中,病毒串长度一般很小,总长度一般不超过50,而n却很大,一般在10^9以上。
如果只有一个病毒串,那么我们只需要KMP就好了,比如
我们先求出A[i][j],表示病毒串从i这个前缀添加一个字符,变为j这个前缀的方案数,这个可以先求出next数组,然后用类似一个匹配去求。
设dp[i][j]表示长度为i,末尾的j个字符为病毒串的前缀的方案数,那么有转移
dp[i][j] = ∑(dp[i][k] * A[k][j]),0 <= k < len
这是一个线性递推,我们可以用矩阵快速幂加速,转移矩阵即A数组。
答案为∑dp[len][i],0 <= i < len
直接上代码。
/* Pigonometry */
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 25;
int n, m, p, s[maxn], fail[maxn];
struct _mat {
int num[maxn][maxn];
} E, trans;
inline int iread() {
int f = 1, x = 0; char ch = getchar();
for(; ch < '0' || ch > '9'; ch = getchar()) f = ch == '-' ? -1 : 1;
for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
return f * x;
}
inline _mat mul(_mat &A, _mat &B) {
_mat C;
for(int i = 0; i < m; i++) for(int j = 0; j < m; j++) {
C.num[i][j] = 0;
for(int k = 0; k < m; k++) C.num[i][j] = (C.num[i][j] + A.num[i][k] * B.num[k][j]) % p;
}
return C;
}
inline _mat qpow(_mat &A, int n) {
_mat ans = E;
for(_mat t = A; n; n >>= 1, t = mul(t, t)) if(n & 1) ans = mul(ans, t);
return ans;
}
char str[maxn];
int main() {
n = iread(); m = iread(); p = iread(); scanf("%s", str + 1);
for(int i = 1; i <= m; i++) s[i] = str[i] - '0';
for(int i = 2, j = 0; i <= m; fail[i++] = j) {
for(; j != 0 && s[j + 1] != s[i]; j = fail[j]);
if(s[j + 1] == s[i]) j++;
}
for(int i = 0; i < m; i++) for(int j = 0; j <= 9; j++) {
int k = i;
for(; k != 0 && s[k + 1] != j; k = fail[k]);
if(s[k + 1] == j) k++;
trans.num[i][k] = (trans.num[i][k] + 1) % p;
}
for(int i = 0; i < m; i++) E.num[i][i] = 1;
_mat res = qpow(trans, n);
int ans = 0;
for(int i = 0; i < m; i++) ans = (ans + res.num[0][i]) % p;
printf("%d\n", ans);
return 0;
}
如果有多个病毒串,这时候就要用AC自动机了,思路类似。
这个题数据范围较小,可以不用矩阵快速幂。
题目求的是至少包含一个单词的方案数,我们转化为 总方案数 - 一个单词都不包含的方案数。
前者是26^m,一个快速幂就好了,后者用dp求。
设dp[i][j],表示字符串长度为i时,在AC自动机上的第j个节点,不包含病毒串的方案数,那么有转移
dp[i][son[j][k]] += dp[i - 1][j],0 <= k < 26,且son[j][k]不为病毒串的结尾(AC自动机插入和求fail数组时可以预处理出)。
答案为26^m - ∑dp[m][i]
/* Pigonometry */
#include <cstdio>
#include <cstring>
#define cls(a, x) memset(a, x, sizeof(a))
using namespace std;
const int maxn = 6005, maxm = 105, p = 10007, maxq = 10000;
int dp[maxm][maxn], q[maxq];
struct _acm {
int son[maxn][26], fail[maxn], acmcnt;
bool flag[maxn];
void init() {
cls(son, 0); for(int i = 0; i < maxn; i++) fail[i] = flag[i] = 0;
acmcnt = 0;
}
void insert(char *s) {
int now = 0, len = strlen(s);
for(int i = 0; i < len; i++) {
int &pos = son[now][s[i] - 'A'];
if(!pos) pos = ++acmcnt;
now = pos;
}
flag[now] = 1;
}
void getfail() {
int h = 0, t = 0;
for(int i = 0; i < 26; i++) if(son[0][i]) q[t++] = son[0][i];
while(h != t) {
int u = q[h++];
for(int i = 0; i < 26; i++)
if(!son[u][i]) son[u][i] = son[fail[u]][i];
else {
fail[q[t++] = son[u][i]] = son[fail[u]][i];
flag[son[u][i]] |= flag[fail[son[u][i]]];
}
}
}
} acm;
inline int qpow(int a, int n) {
int ans = 1;
for(int t = a; n; n >>= 1, t = t * t % p) if(n & 1) ans = ans * t % p;
return ans;
}
char str[maxm];
int main() {
int n, m; scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) {
scanf("%s", str);
acm.insert(str);
}
acm.getfail();
dp[0][0] = 1;
for(int i = 1; i <= m; i++) for(int j = 0; j <= acm.acmcnt; j++) for(int k = 0; k < 26; k++)
if(!acm.flag[acm.son[j][k]]) dp[i][acm.son[j][k]] = (dp[i][acm.son[j][k]] + dp[i - 1][j]) % p;
int ans = qpow(26, m);
for(int i = 0; i <= acm.acmcnt; i++) ans = (ans - dp[m][i] + p) % p;
printf("%d\n", ans);
return 0;
}
上个题如果m非常大,那么就要用矩阵快速幂了。换种思路(其实还是线性递推)。
设A[i][j]表示从AC自动机上的第i个节点添加一个字符到第j个节点的方案数。
A数组可以枚举i,然后枚举i的儿子求出。
然后对A跑矩阵快速幂就好了。
/* Pigonometry */
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const int maxn = 150, maxnode = maxn, maxd = 4, maxq = 10000, p = 100000;
int n, m, id[26], q[maxq], size;
struct _acm {
int son[maxnode][maxd], acmcnt, fail[maxnode];
bool flag[maxnode];
void init() {
memset(son, 0, sizeof(son)); acmcnt = 0; memset(fail, 0, sizeof(fail)); memset(flag, 0, sizeof(flag));
}
void insert(string s) {
int now = 0, len = s.size();
for(int i = 0; i < len; i++) {
int &pos = son[now][id[s[i] - 'A']];
if(!pos) pos = ++acmcnt;
now = pos;
}
flag[now]++;
}
void getfail() {
int h = 0, t = 0;
for(int i = 0; i < 4; i++) if(son[0][i]) q[t++] = son[0][i];
while(h != t) {
int u = q[h++];
for(int i = 0; i < 4; i++)
if(!son[u][i]) son[u][i] = son[fail[u]][i];
else {
fail[q[t++] = son[u][i]] = son[fail[u]][i];
flag[son[u][i]] |= flag[fail[son[u][i]]];
}
}
}
} acm;
struct _matrix {
int num[maxn][maxn];
} trans, E;
inline _matrix matmul(_matrix A, _matrix B) {
_matrix ans;
for(int i = 0; i < size; i++) for(int j = 0; j < size; j++) {
ans.num[i][j] = 0;
for(int k = 0; k < size; k++) ans.num[i][j] = (ans.num[i][j] + ((LL)A.num[i][k] * B.num[k][j]) % p) % p;
}
return ans;
}
_matrix matqpow(_matrix A, int n) {
_matrix s = E;
for(_matrix t = A; n; n >>= 1, t = matmul(t, t)) if(n & 1) s = matmul(s, t);
return s;
}
int main() {
ios::sync_with_stdio(false);
cin >> m >> n;
acm.init();
id['A' - 'A'] = 0; id['C' - 'A'] = 1; id['G' - 'A'] = 2; id['T' - 'A'] = 3;
for(int i = 1; i <= m; i++) {
string str; cin >> str;
acm.insert(str);
}
acm.getfail();
size = acm.acmcnt + 1;
for(int i = 0; i < size; i++) E.num[i][i] = 1;
for(int i = 0; i < size; i++) if(!acm.flag[i])
for(int j = 0; j < 4; j++) if(!acm.flag[acm.son[i][j]])
trans.num[i][acm.son[i][j]]++;
_matrix res = matqpow(trans, n);
int ans = 0;
for(int i = 0; i < size; i++) ans = (ans + res.num[0][i]) % p;
printf("%d\n", ans);
return 0;
}
另外还有一个比较有趣的拓展,这个
这个题求长度不小于m的字符串的方案数。
即求A^1 + A^2 + A^3 + ... + A^m
可以构造一个分块矩阵,长这样:
A E
0 E
其中E为单位矩阵,对这个分块矩阵跑快速幂就好了。
/* Pigonometry */
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef unsigned long long ULL;
const int maxn = 105, maxnode = 40, maxd = 26, maxq = maxn;
ULL n, l;
int size, q[maxq];
struct _acm {
int son[maxnode][maxd], acmcnt, fail[maxnode];
bool flag[maxnode];
void init() {
memset(son, 0, sizeof(son)); acmcnt = 0; memset(fail, 0, sizeof(fail)); memset(flag, 0, sizeof(flag));
}
void insert(string s) {
int now = 0, len = s.size();
for(int i = 0; i < len; i++) {
int &pos = son[now][s[i] - 'a'];
if(!pos) pos = ++acmcnt;
now = pos;
}
flag[now]++;
}
void getfail() {
int h = 0, t = 0;
for(int i = 0; i < 26; i++) if(son[0][i]) q[t++] = son[0][i];
while(h != t) {
int u = q[h++];
for(int i = 0; i < 26; i++)
if(!son[u][i]) son[u][i] = son[fail[u]][i];
else {
fail[q[t++] = son[u][i]] = son[fail[u]][i];
flag[son[u][i]] |= flag[fail[son[u][i]]];
}
}
}
} acm;
struct _matrix {
ULL num[maxn][maxn];
} trans, E, two;
_matrix matmul(_matrix A, _matrix B) {
_matrix ans;
for(int i = 0; i < size; i++) for(int j = 0; j < size; j++) {
ans.num[i][j] = 0;
for(int k = 0; k < size; k++) ans.num[i][j] += A.num[i][k] * B.num[k][j];
}
return ans;
}
_matrix matqpow(_matrix A, ULL n) {
_matrix s = E;
for(_matrix t = A; n; n >>= 1, t = matmul(t, t)) if(n & 1) s = matmul(s, t);
return s;
}
int main() {
ios::sync_with_stdio(false);
for(int i = 0; i < maxn; i++) E.num[i][i] = 1;
while(cin >> n >> l) {
acm.init();
for(int i = 1; i <= n; i++) {
string str; cin >> str;
acm.insert(str);
}
acm.getfail();
memset(trans.num, 0, sizeof(trans.num));
for(int i = 0; i <= acm.acmcnt; i++) if(!acm.flag[i])
for(int j = 0; j < 26; j++) if(!acm.flag[acm.son[i][j]])
trans.num[i][acm.son[i][j]]++;
for(int i = 1; i <= acm.acmcnt + 1; i++) trans.num[i - 1][acm.acmcnt + i] = trans.num[acm.acmcnt + i][acm.acmcnt + i] = 1;
size = (acm.acmcnt << 1) + 2;
_matrix res = matqpow(trans, l);
ULL ans = 0;
for(int i = 0; i < size; i++) if(!acm.flag[i]) ans += res.num[0][i];
memset(two.num, 0, sizeof(two.num));
two.num[0][0] = 26; two.num[0][1] = 1;
two.num[1][0] = 0; two.num[1][1] = 1;
size = 2;
_matrix ret = matqpow(two, l);
ULL tot = ret.num[0][0] + ret.num[0][1];
cout << tot - ans << endl;
}
return 0;
}
还有一些其他类型的DP,比如
【Codeforces86C: Genetic engineering】
有m个模板串,要求字符串中每一个字符都至少被一个模板串覆盖,求长度为n的字符串个数。
设cover[i]表示AC自动机上第i个节点,以这个节点为结尾的模板串的长度的最大值。
设dp[i][j][k]表示长度为i,在AC自动机上第j个节点,结尾的k个字符未匹配的方案数。有转移
dp[i + 1][u][0] += dp[i][j][k],u为i的儿子,且cover[u] >= k + 1
dp[i + 1][u][k + 1] += dp[i][j][k],u为i的儿子,且cover[u] < k + 1
/* Pigonometry */
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 1005, maxd = 5, maxnode = maxn, maxq = maxnode, p = 1000000009;
int n, m, id[30], dp[maxn][105][15], q[maxq];
struct _acm {
int son[maxnode][maxd], acmcnt, fail[maxnode], cover[maxnode];
void init() {
memset(son, 0, sizeof(son)); acmcnt = 0; memset(fail, 0, sizeof(fail)); memset(cover, 0, sizeof(cover));
}
void insert(string s) {
int now = 0, len = s.size();
for(int i = 0; i < len; i++) {
int &pos = son[now][id[s[i] - 'A']];
if(!pos) pos = ++acmcnt;
now = pos;
}
cover[now] = len;
}
void getfail() {
int h = 0, t = 0;
for(int i = 0; i < 4; i++) if(son[0][i]) q[t++] = son[0][i];
while(h != t) {
int u = q[h++];
for(int i = 0; i < 4; i++) {
int &pos = son[u][i];
if(!pos) pos = son[fail[u]][i];
else {
fail[q[t++] = pos] = son[fail[u]][i];
cover[pos] = max(cover[pos], cover[fail[pos]]);
}
}
}
}
} acm;
int main() {
ios::sync_with_stdio(false);
cin >> n >> m;
acm.init();
id['A' - 'A'] = 0; id['C' - 'A'] = 1; id['G' - 'A'] = 2; id['T' - 'A'] = 3;
for(int i = 1; i <= m; i++) {
string str; cin >> str;
acm.insert(str);
}
acm.getfail();
dp[0][0][0] = 1;
for(int i = 0; i < n; i++) for(int j = 0; j <= acm.acmcnt; j++) for(int k = 0; k <= 10; k++) if(dp[i][j][k])
for(int l = 0; l < 4; l++) {
int u = acm.son[j][l];
if(acm.cover[u] >= k + 1) dp[i + 1][u][0] = (dp[i + 1][u][0] + dp[i][j][k]) % p;
else dp[i + 1][u][k + 1] = (dp[i + 1][u][k + 1] + dp[i][j][k]) % p;
}
int ans = 0;
for(int i = 0; i <= acm.acmcnt; i++) ans = (ans + dp[n][i][0]) % p;
printf("%d\n", ans);
return 0;
}
给出n个病毒串,和一个字符串,问至少修改多少个字符,使得这个字符串不包含病毒串。
设dp[i][j]表示长度为i,在AC自动机上第j个节点,至少修改了多少字符。
枚举j的儿子u,如果u和原字符串的字符不相同,那么就要修改。
dp[i][u] = min(dp[i][u], dp[i - 1][j] + [k与原字符串的字符不相同])
/* Pigonometry */
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 1005, maxd = 4, maxnode = maxn, maxq = maxnode;
int n, id[26], dp[maxn][maxnode], q[maxq];
struct _acm {
int son[maxnode][maxd], acmcnt, fail[maxnode];
bool flag[maxnode];
void init() {
memset(son, 0, sizeof(son)); acmcnt = 0; memset(fail, 0, sizeof(fail)); memset(flag, 0, sizeof(flag));
}
void insert(string s) {
int now = 0, len = s.size();
for(int i = 0; i < len; i++) {
int &pos = son[now][id[s[i] - 'A']];
if(!pos) pos = ++acmcnt;
now = pos;
}
flag[now]++;
}
void getfail() {
int h = 0, t = 0;
for(int i = 0; i < 4; i++) if(son[0][i]) q[t++] = son[0][i];
while(h != t) {
int u = q[h++];
for(int i = 0; i < 4; i++) {
int &v = son[u][i];
if(!v) v = son[fail[u]][i];
else {
fail[q[t++] = v] = son[fail[u]][i];
flag[v] |= flag[fail[v]];
}
}
}
}
int getans(string s) {
memset(dp, 0x3f, sizeof(dp));
dp[0][0] = 0;
int len = s.size();
for(int i = 1; i <= len; i++) for(int j = 0; j <= acmcnt; j++) if(dp[i - 1][j] != 0x3f3f3f3f)
for(int k = 0; k < 4; k++) {
int u = son[j][k];
if(!flag[u]) dp[i][u] = min(dp[i][u], dp[i - 1][j] + (k != id[s[i - 1] - 'A']));
}
int ans = 0x3f3f3f3f;
for(int i = 0; i <= acmcnt; i++) if(!flag[i]) ans = min(ans, dp[len][i]);
if(ans == 0x3f3f3f3f) ans = -1;
return ans;
}
} acm;
int main() {
ios::sync_with_stdio(false);
id['A' - 'A'] = 0; id['C' - 'A'] = 1; id['G' - 'A'] = 2; id['T' - 'A'] = 3;
for(int cas = 1; ; cas++) {
cin >> n;
if(!n) break;
acm.init();
for(int i = 1; i <= n; i++) {
string str; cin >> str;
acm.insert(str);
}
acm.getfail();
string str; cin >> str;
printf("Case %d: %d\n", cas, acm.getans(str));
}
return 0;
}