「算法笔记」AC自动机

学习笔记

这篇Blog讲的很清楚。

下面是个人的见解。

AC自动机 = trie + KMP

每一个节点有一个fail指针和alphabet个next指针。

插入:同trie。
计算fail指针:bfs版的KMP。
字符串匹配:基本上与KMP相同。
详见代码。

模版(指针版)
// Luogu P3808 AC自动机模版
#include <cstdio>
#include <cstring>
const int maxn = 1000005;
const int maxm = 26;
struct node {
    int cnt;
    node *fail, *next[26];
    void init() {
        cnt = 0, fail = NULL;
        memset(next, NULL, sizeof(next));
    }
} nodes[maxn], *que[maxn];
struct dfa {
    int e;
    node *root;
    node* _add() {
        nodes[e].init();
        return &nodes[e++];
    }
    void init() {
        e = 0;
        root = _add();
    }
    void insert(char *s) {
        node *u = root;
        for (int i = 0, ch; s[i]; i++) {
            ch = s[i] - 'a';
            if (u -> next[ch] == NULL) {
                u -> next[ch] = _add();
            }
            u = u -> next[ch]; 
        }
        u -> cnt++;
    }
    void getfail() {
        root -> fail = root;
        int h = 0, t = 0;
        node *u;
        for (int i = 0; i < maxm; i++) {
            if (root -> next[i]) {
                root -> next[i] -> fail = root;
                que[t++] = root -> next[i];
            } else {
                root -> next[i] = root;
            }
        }
        while (h < t) {
            u = que[h++];
            for (int i = 0; i < maxm; i++) {
                if (u -> next[i]) {
                    u -> next[i] -> fail = u -> fail -> next[i];
                    que[t++] = u -> next[i];
                } else {
                    u -> next[i] = u -> fail -> next[i];
                }
            }
        }
    }
    int match(char *s) {
        int ch, ans = 0;
        node *u = root, *tmp;
        for (int i = 0; s[i]; i++) {
            ch = s[i] - 'a';
            tmp = u = u -> next[ch];
            while (tmp != root && tmp -> cnt != -1) {
                ans += tmp -> cnt;
                tmp -> cnt = -1;
                tmp = tmp -> fail;
            }
        }
        return ans;
    }
} ac;
int n;
char s[maxn];
int main() {
    ac.init();
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%s", s);
        ac.insert(s);
    }
    ac.getfail();
    scanf("%s", s);
    printf("%d\n", ac.match(s));
    return 0;
}
模版(非指针)
// BZOJ 4327 玄武密码
#include <cstdio>
#include <queue>
#include <cstring>
using namespace std;
const int maxn = 100005;
const int maxm = 10000005;
bool vis[maxm];
char s[maxm], t[maxn];
int m, n, b[1 << 8];
int pos[maxn], len[maxn];
int tot, nxt[maxm][4], pre[maxm], lst[maxm];
void prework() {
    tot = 1;
    pre[1] = 1;
}
int insert() {
    int u = 1;
    for (int i = 1; t[i]; i++) {
        if (!nxt[u][b[t[i]]]) {
            nxt[u][b[t[i]]] = ++tot;
            lst[tot] = u;
        }
        u = nxt[u][b[t[i]]];
    }
    return u;
}
void getfail() {
    queue<int> que;
    for (int i = 0; i < 4; i++) {
        if (nxt[1][i]) {
            pre[nxt[1][i]] = 1;
            que.push(nxt[1][i]);
        } else {
            nxt[1][i] = 1;
        }
    }
    while (!que.empty()) {
        int u = que.front();
        que.pop();
        for (int i = 0; i < 4; i++) {
            if (nxt[u][i]) {
                pre[nxt[u][i]] = nxt[pre[u]][i];
                que.push(nxt[u][i]);
            } else {
                nxt[u][i] = nxt[pre[u]][i];
            }
        }
    }
}
void match() {
    int u = 1, v;
    vis[u] = 1;
    for (int i = 1; s[i]; i++) {
        u = v = nxt[u][b[s[i]]];
        while (!vis[v]) {
            vis[v] = 1;
            v = pre[v];
        }
    }
}
int main() {
    scanf("%d %d", &m, &n);
    scanf("%s", s + 1);
    b['E'] = 0, b['S'] = 1;
    b['W'] = 2, b['N'] = 3;
    prework();
    for (int i = 1; i <= n; i++) {
        scanf("%s", t + 1);
        pos[i] = insert();
        len[i] = strlen(t + 1);
    }
    getfail();
    match();
    for (int i = 1; i <= n; i++) {
        int u = pos[i], c = 0;
        while (!vis[u]) {
            u = lst[u];
            c++;
        }
        printf("%d\n", len[i] - c);
    }
    return 0;
}

下文中记自动机节点数为 tot t o t


应用一:朴素字符串匹配

这个没什么好说的······

只要套模版即可。


应用二:动态规划

AC自动机上也是可以做动态规划的。

例题一:BZOJ 1030 文本生成器

dp[i][j][1 / 0] d p [ i ] [ j ] [ 1   /   0 ] 表示长度为 i i 的字符串,匹配到了AC自动机的第j个节点,有 / 没有包含关键词的字符串数量。

#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
const int maxn = 105;
const int maxm = 6005;
const int mod = 10007;
char s[maxn];
int n, m, dp[maxn][maxm][2];
int tot, nxt[maxm][26], pre[maxm], val[maxm];
void prework() {
    tot = 1;
    memset(nxt, 0, sizeof(nxt));
    memset(pre, 0, sizeof(pre));
    memset(val, 0, sizeof(val));
    pre[1] = 1;
}
void insert() {
    int u = 1;
    for (int i = 1; s[i]; i++) {
        if (!nxt[u][s[i] - 'A']) {
            nxt[u][s[i] - 'A'] = ++tot;
        }
        u = nxt[u][s[i] - 'A'];
    }
    val[u] = 1;
}
void getfail() {
    queue<int> que;
    for (int i = 0; i < 26; i++) {
        if (!nxt[1][i]) {
            nxt[1][i] = 1;
        } else {
            pre[nxt[1][i]] = 1;
            que.push(nxt[1][i]);
        }
    }
    for (; !que.empty(); ) {
        int u = que.front();
        que.pop();
        for (int i = 0; i < 26; i++) {
            if (!nxt[u][i]) {
                nxt[u][i] = nxt[pre[u]][i];
            } else {
                pre[nxt[u][i]] = nxt[pre[u]][i];
                que.push(nxt[u][i]);
            }
        }
        val[u] |= val[pre[u]];
    }
}
void update(int &x, int y) {
    x += y, x -= x >= mod ? mod : 0;
}
int main() {
    prework();
    scanf("%d %d", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%s", s + 1);
        insert();
    }
    getfail();
    dp[0][1][0] = 1;
    for (int i = 1; i <= m; i++) {
        for (int j = 1; j <= tot; j++) {
            for (int k = 0; k < 26; k++) {
                if (val[nxt[j][k]]) {
                    update(dp[i][nxt[j][k]][1], (dp[i - 1][j][0] + dp[i - 1][j][1]) % mod);
                } else {
                    update(dp[i][nxt[j][k]][0], dp[i - 1][j][0]);
                    update(dp[i][nxt[j][k]][1], dp[i - 1][j][1]);
                }
            }
        }
    }
    int res = 0;
    for (int i = 1; i <= tot; i++) {
        update(res, dp[m][i][1]);
    }
    printf("%d\n", res);
    return 0;
}
  • AC自动机DP的一般形式: dp[length][node] d p [ l e n g t h ] [ n o d e ]

但是有时候字符串长度特别大,我们就需要用矩阵乘法优化DP。

例题二:BZOJ 2553 禁忌

一个字符串的“禁忌伤害”其实就是它的最大不可重匹配。
dp[i][j] d p [ i ] [ j ] 表示长度为 i i 的串匹配到自动机第j个节点的概率, dp[i][tot+1] d p [ i ] [ t o t + 1 ] 用于记录答案。

直接DP不可行,考虑使用矩阵乘法优化。

具体DP转移见代码。

#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
typedef long double ldb;
const int maxl = 20;
const int maxn = 80;
char s[maxl];
bool vis[maxn];
int n, m, x, alpha;
int tot, nxt[maxn][26], pre[maxn], val[maxn];
ldb dp[maxn][maxn], f[maxn][maxn], t[maxn][maxn];
void prework() {
    tot = 1;
    memset(nxt, 0, sizeof(nxt));
    memset(pre, 0, sizeof(pre));
    memset(val, 0, sizeof(val));
    pre[1] = 1;
}
void insert() {
    int u = 1;
    for (int i = 1; s[i]; i++) {
        if (!nxt[u][s[i] - 'a']) {
            nxt[u][s[i] - 'a'] = ++tot;
        }
        u = nxt[u][s[i] - 'a'];
    }
    val[u] = 1;
}
void getfail() {
    queue<int> que;
    for (int i = 0; i < alpha; i++) {
        if (nxt[1][i]) {
            pre[nxt[1][i]] = 1;
            que.push(nxt[1][i]);
        } else {
            nxt[1][i] = 1;
        }
    }
    for (; !que.empty(); ) {
        int u = que.front();
        val[u] |= val[pre[u]];
        que.pop();
        for (int i = 0; i < alpha; i++) {
            if (nxt[u][i]) {
                pre[nxt[u][i]] = nxt[pre[u]][i];
                que.push(nxt[u][i]);
            } else {
                nxt[u][i] = nxt[pre[u]][i];
            }
        }
    }
}
void getmatrix() {
    ldb beta = 1. / alpha;
    vis[1] = 1;
    queue<int> que;
    que.push(1);
    for (; !que.empty(); ) {
        int u = que.front();
        que.pop();
        for (int i = 0; i < alpha; i++) {
            if (!vis[nxt[u][i]]) {
                vis[nxt[u][i]] = 1;
                que.push(nxt[u][i]);
            } 
            if (val[nxt[u][i]]) {
                f[u][1] += beta;
                f[u][x] += beta;
            } else {
                f[u][nxt[u][i]] += beta;
            }
        }
    }
    f[x][x] = 1;
}
void multiply(ldb a[maxn][maxn], ldb b[maxn][maxn], ldb c[maxn][maxn]) {
    for (int i = 1; i <= x; i++) {
        for (int j = 1; j <= x; j++) {
            t[i][j] = 0;
            for (int k = 1; k <= x; k++) {
                t[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    for (int i = 1; i <= x; i++) {
        for (int j = 1; j <= x; j++) {
            c[i][j] = t[i][j];
        }
    }
}
int main() {
    prework();
    scanf("%d %d %d", &n, &m, &alpha);
    for (int i = 1; i <= n; i++) {
        scanf("%s", s + 1);
        insert();
    }
    getfail();
    x = tot + 1;
    getmatrix();
    for (int i = 1; i <= x; i++) {
        dp[i][i] = 1;
    }
    for (; m; m >>= 1, multiply(f, f, f)) {
        /*
        for (int i = 1; i <= x; i++) {
            for (int j = 1; j <= x; j++) {
                printf("%.8lf%c", (double) f[i][j], j == x ? '\n' : ' ');
            }
        }
        puts("---------------------------");
        */
        if (m & 1) {
            multiply(dp, f, dp);
        }
    }
    printf("%.8lf\n", (double) dp[1][x]);
    return 0;
}
  • 当字符串长度很大时,可以使用矩阵乘法优化

有时,字符串长度为无穷大,此时,我们要么高斯消元(编程复杂度较高),要么迭代(编程复杂度较低)。

例题三:BZOJ 1444 有趣的游戏

DP的具体方法请读者自行思考(或浏览代码)。

我们构造出了转移矩阵,将其自乘 40+ 40 + 次,即可到达题目要求的精度。

#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
const int maxn = 105;
char s[maxn];
double p[maxn];
int n, m, pos[maxn];
int tot, nxt[maxn][10], pre[maxn], val[maxn];
double mat[maxn][maxn], tmp[maxn][maxn];
void prework() {
    tot = 1;
    memset(nxt, 0, sizeof(nxt));
    memset(pre, 0, sizeof(pre));
    memset(val, 0, sizeof(val));
    pre[1] = 1;
}
int insert() {
    int u = 1;
    for (int i = 1; s[i]; i++) {
        if (!nxt[u][s[i] - 'A']) {
            nxt[u][s[i] - 'A'] = ++tot;
        }
        u = nxt[u][s[i] - 'A'];
    }
    val[u] = 1;
    return u;
}
void getfail() {
    queue<int> que;
    for (int i = 0; i < m; i++) {
        if (nxt[1][i]) {
            pre[nxt[1][i]] = 1;
            que.push(nxt[1][i]);
        } else {
            nxt[1][i] = 1;
        }
    }
    for (; !que.empty(); ) {
        int u = que.front();
        que.pop();
        for (int i = 0; i < m; i++) {
            if (nxt[u][i]) {
                pre[nxt[u][i]] = nxt[pre[u]][i];
                que.push(nxt[u][i]);
            } else {
                nxt[u][i] = nxt[pre[u]][i];
            }
        }
    }
}
void multiply(double a[maxn][maxn], double b[maxn][maxn], double c[maxn][maxn]) {
    for (int i = 1; i <= tot; i++) {
        for (int j = 1; j <= tot; j++) {
            tmp[i][j] = 0;
            for (int k = 1; k <= tot; k++) {
                tmp[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    for (int i = 1; i <= tot; i++) {
        for (int j = 1; j <= tot; j++) {
            c[i][j] = tmp[i][j];
        }
    }
}
int main() {
    scanf("%d %*d %d", &n, &m);
    for (int a, b, i = 0; i < m; i++) {
        scanf("%d %d", &a, &b);
        p[i] = 1. * a / b;
    }
    prework();
    for (int i = 1; i <= n; i++) {
        scanf("%s", s + 1);
        pos[i] = insert();
    }
    getfail();
    for (int i = 1; i <= tot; i++) {
        if (val[i]) {
            mat[i][i] = 1;
        } else {
            for (int j = 0; j < m; j++) {
                mat[i][nxt[i][j]] += p[j];
            }
        }
    }
    for (int i = 0; i < 40; i++) {
        multiply(mat, mat, mat);
    }
    for (int i = 1; i <= n; i++) {
        printf("%.2lf\n", mat[1][pos[i]]);
    }
    return 0;
}

应用三:较高级的字符串匹配

此处就需要用到fail树及其的一些性质了。

这篇Blog讲的很清楚。

  • 重要性质:串S0在串S1中出现了几次 = AC自动机上Root到S1中有多少节点在fail树中S0的子树中。

有了这个性质,我们就可以解决一些更高级的字符串匹配问题了。

例题:BZOJ 2434 阿狸的打字机

建一棵fail树,对于a串,统计子树中有多少个b串的节点即可。
因为子树的节点的dfs序是相连的,所以我们可以用树状数组维护。

注意:此题构图时不能破坏next指针。

#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>
using namespace std;
#define pb push_back
const int maxn = 100005;
char s[maxn];
vector<int> ver[maxn], id[maxn], node[maxn];
int n, m, w[maxn], bit[maxn << 1], ans[maxn];
int cnt, nxt[maxn][26], pre[maxn], lst[maxn];
int cur, lb[maxn], rb[maxn];
void prework() {
    cnt = 1;
    for (int i = 0; i < 26; i++) {
        nxt[0][i] = 1;
    }
}
void insert() {
    int u = 1;
    for (int i = 1; s[i]; i++) {
        if (s[i] == 'B') {
            u = lst[u];
        } else if (s[i] == 'P') {
            w[++n] = u;
        } else {
            if (!nxt[u][s[i] - 'a']) {
                nxt[u][s[i] - 'a'] = ++cnt;
                lst[cnt] = u;
            }
            u = nxt[u][s[i] - 'a'];
        }
    }
}
void getfail() {
    queue<int> que;
    que.push(1);
    while (!que.empty()) {
        int u = que.front();
        que.pop();
        for (int i = 0; i < 26; i++) {
            if (nxt[u][i]) {
                int v = pre[u];
                while (!nxt[v][i]) {
                    v = pre[v];
                }
                pre[nxt[u][i]] = nxt[v][i];
                que.push(nxt[u][i]);
            }
        }
    }
}
void search(int u) {
    lb[u] = ++cur;
    for (int i = 0; i < ver[u].size(); i++) {
        search(ver[u][i]);
    }
    rb[u] = ++cur;
}
void add(int x, int y) {
    for (int i = x; i <= cur; i += i & -i) {
        bit[i] += y;
    }
}
int sum(int x) {
    int res = 0;
    for (int i = x; i; i ^= i & -i) {
        res += bit[i];
    }
    return res;
}
void solve() {
    int u = 1, x = 0;
    add(lb[1], 1);
    for (int i = 1; s[i]; i++) {
        if (s[i] == 'B') {
            add(lb[u], -1);
            u = lst[u];
        } else if (s[i] == 'P') {
            x++;
            for (int j = 0; j < id[x].size(); j++) {
                ans[id[x][j]] = sum(rb[node[x][j]]) - sum(lb[node[x][j]] - 1);
            }
        } else {
            u = nxt[u][s[i] - 'a'];
            add(lb[u], 1);
        }
    }
}
int main() {
    prework();
    scanf("%s", s + 1);
    insert();
    getfail();
    for (int i = 1; i <= cnt; i++) {
        ver[pre[i]].pb(i);
    }
    search(0);
    /*
    for (int i = 1; i <= n; i++) {
        printf("w[%d] = %d\n", i, w[i]);
    }
    for (int i = 1; i <= cnt; i++) {
        printf("%d: fa = %d, son = {%d, %d}, fail = %d, [%d, %d]\n", i, lst[i], nxt[i][0], nxt[i][1], pre[i], lb[i], rb[i]);
    }
    */
    scanf("%d", &m);
    for (int x, y, i = 1; i <= m; i++) {
        scanf("%d %d", &x, &y);
        id[y].pb(i);
        node[y].pb(w[x]);
    }
    solve();
    for (int i = 1; i <= m; i++) {
        printf("%d\n", ans[i]);
    }
    return 0;
}

总结

  • AC自动机 = trie + KMP。
  • AC自动机DP的一般形式: dp[length][node] d p [ l e n g t h ] [ n o d e ]
  • 字符串长度太大时,可以考虑矩阵乘法优化。
  • 长度为无限大时,只需将矩阵自乘多次即可代替高斯消元。
  • 串S0在串S1中出现了几次 = AC自动机上Root到S1中有多少节点在fail树中S0的子树中。

刷题列表

  • 【HNOI 2008】GT考试
  • 【JSOI 2007】 文本生成器
  • 【BZOJ 1212】L语言
  • 【JSOI 2009】 有趣的游戏
  • 【NOI 2011】 阿狸的打字机
  • 【BJOI 2011】 禁忌
  • 【POI 2000】 病毒
  • 【TJOI 2013】单词
  • 【JSOI 2012】玄武密码
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值