AC自动机
前言
在字符串匹配的相关问题中, 很常见的思想就是避免完全回溯。比如在KMP算法中, 如果到了某个位置发现不匹配, 不会从头重新开始匹配, 而是跳到下一个可能匹配的位置, 从而避免算法复杂度退化到 $O(m \cdot n)$。KMP利用next数组让一对字符串的匹配问题的复杂度达到了 $O(m + n)$, 而AC自动机解决的则是多字符串匹配问题, 如下
多字符串匹配问题:
给定字符串S, 和字符串集合 $T=\{T_1, T_2, T_3, \dots, T_n\}$
查找集合T中的字符串在S中的出现次数
背景知识
KMP
kmp算法的核心在于next数组
Next数组:
对于str
next[i] 代表str[0..i]中既是真前缀又是后缀的最长子串的长度
令l=next[i]
也就是str[0..l-1]=str[i-l+1..i] (下标均为闭区间)
有了next数组后, 当str1和str2在不匹配的时候, 就不会下一次跳到str2的第一个字符去, 而是跳到next数组对应的位置上。
代码如下
/// Knuth-Morris-Pratt matcher for one fixed pattern.
///
/// The constructor precomputes the failure table `next`, where next[i] is the
/// length of the longest proper prefix of pattern[0..i] that is also a suffix
/// of pattern[0..i]. Searching then never moves backwards in the text.
class Kmp {
public:
    /// Stores the pattern `str` and builds its failure table.
    Kmp(const std::string& str) : next(str.length(), 0), str(str) {
        const int len = static_cast<int>(str.length());
        int cur = 0;  // length of the border currently being extended
        for (int i = 1; i < len; ) {
            if (str[i] == str[cur]) {
                next[i++] = ++cur;
            } else if (cur != 0) {
                // Fall back to the next shorter border and retry the same i.
                cur = next[cur - 1];
            } else {
                ++i;  // no border ends here; next[i] stays 0
            }
        }
    }

    /// Returns the start index of every (possibly overlapping) occurrence of
    /// the pattern inside `other`, in increasing order.
    /// An empty pattern yields an empty result.
    std::vector<int> GetOccPos(const std::string& other) const {
        std::vector<int> res;
        const int pat_len = static_cast<int>(str.length());
        // Guard: with an empty pattern the loop below would index past the
        // pattern as soon as the text contained a '\0' character.
        if (pat_len == 0)
            return res;
        const int text_len = static_cast<int>(other.length());
        int cur = 0;  // number of pattern characters currently matched
        for (int i = 0; i < text_len; ) {
            if (other[i] == str[cur]) {
                ++i;
                ++cur;
                if (cur == pat_len) {
                    res.push_back(i - pat_len);
                    // Keep the longest border so overlapping matches are found.
                    cur = next[cur - 1];
                }
            } else if (cur != 0) {
                cur = next[cur - 1];
            } else {
                ++i;
            }
        }
        return res;
    }

    /// Returns a copy of the failure table.
    std::vector<int> GetNextArray() const { return next; }

private:
    std::vector<int> next;  // failure table, one entry per pattern character
    std::string str;        // the pattern
};
Trie
Trie树是一种字符串查找的树, 每个节点代表一个字符, 子节点是接下来的字符
以这张图为例
插入字符串code, cook, five, file, fat 得到以下的一棵树
trie的代码参考如下
/**
 * Prefix tree over the lowercase letters 'a'..'z'.
 *
 * Nodes live in a vector and refer to each other by index. Index 0 is the
 * root, so a 0 stored in a child slot means "no child".
 */
class Trie {
public:
    /** Creates a trie that contains only the root node. */
    Trie() : tree_(1, TrieNode()) {}

    /** Inserts `word`, creating any nodes missing along its path. */
    void insert(string word) {
        int node = 0;
        for (char ch : word) {
            const int slot = ch - 'a';
            int child = tree_[node].next[slot];
            if (child == 0) {
                child = static_cast<int>(tree_.size());
                tree_.push_back(TrieNode());
                tree_[node].next[slot] = child;
            }
            node = child;
        }
        tree_[node].is_end = true;
    }

    /** Returns true if exactly `word` was inserted before. */
    bool search(string word) {
        const int node = Locate(word);
        return node >= 0 && tree_[node].is_end;
    }

    /** Returns true if some inserted word starts with `prefix`. */
    bool startsWith(string prefix) {
        return Locate(prefix) >= 0;
    }

private:
    struct TrieNode {
        int next[26];  // child index per letter, 0 = absent
        bool is_end;   // true if an inserted word ends at this node
    };

    /** Follows `path` from the root; returns the node reached, or -1 if the path leaves the trie. */
    int Locate(const string& path) {
        int node = 0;
        for (char ch : path) {
            node = tree_[node].next[ch - 'a'];
            if (node == 0)
                return -1;
        }
        return node;
    }

    vector<TrieNode> tree_;
};
/**
* Your Trie object will be instantiated and called as such:
* Trie* obj = new Trie();
* obj->insert(word);
* bool param_2 = obj->search(word);
* bool param_3 = obj->startsWith(prefix);
*/
AC自动机算法
AC自动机是对于trie的改进, 加入了fail指针, 解决多字符串匹配问题
// Node of an Aho-Corasick automaton over the 26 lowercase letters:
// a trie node extended with a fail pointer.
struct AcNode {
    int next[26];  // child / transition index for each letter
    int fail;      // index of the node the fail pointer refers to
    bool is_end;   // marks the end of an inserted string
};
其中fail指针指向的是: 当前结点所代表的字符串的、同时也出现在trie中(即是某个模式串的前缀)的最长真后缀所对应的结点
画了个示意图, 其中红线就是fail指针
对于abc来说, 出现在ac自动机的最长后缀是bc因此abc指向bc, 对于ab来说 最长后缀是b, 因此ab指向b,这张图比较形象地展示了ac自动机的工作方式
下面看代码
#include <queue>
#include <cstdlib>
#include <cmath>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
// Capacity (number of nodes) of the tables below.
const int maxn = 2*1e6+9;
int trie[maxn][26]; // Trie: trie[u][c] is the node reached from node u by letter 'a' + c (0 = absent until getFail() fills it in)
int cntword[maxn]; // Number of inserted words that end at each node; query() overwrites visited entries with -1
int fail[maxn]; // Fail pointer followed when matching cannot continue
int cnt = 0; // Index of the most recently allocated node; node 0 is the root
// Adds the word `s` to the trie, allocating nodes on demand, and bumps the
// word counter of the node where it ends.
void insertWords(string s){
    int node = 0;
    for (char ch : s) {
        int &child = trie[node][ch - 'a'];
        if (child == 0)
            child = ++cnt;  // allocate a fresh node for this letter
        node = child;
    }
    ++cntword[node];  // one more word ends here
}
// Computes the fail pointer of every node with a breadth-first traversal and,
// along the way, fills every missing transition trie[now][i] with the
// transition of the fail node, so that query() can follow trie[][] directly.
// Call once after all words have been inserted.
void getFail(){
    queue<int> q;
    // Seed the queue with every existing child of the root; their fail
    // pointer is the root itself.
    for(int i=0;i<26;i++){
        if(trie[0][i]){
            fail[trie[0][i]] = 0;
            q.push(trie[0][i]);
        }
    }
    // fail[now]    -> node that the fail pointer of node `now` refers to
    // trie[now][i] -> index of the node reached from `now` by letter i+'a'
    while(!q.empty()){
        int now = q.front();
        q.pop();
        for(int i=0;i<26;i++){ // try all 26 letters
            if(trie[now][i]){
                // The child exists: its fail pointer is the node reached by
                // the same letter from the fail node of its parent.
                fail[trie[now][i]] = trie[fail[now]][i];
                q.push(trie[now][i]);
            }
            else
                // The child is missing: redirect the transition to the
                // corresponding transition of the fail node.
                trie[now][i] = trie[fail[now]][i];
        }
    }
}
// Scans the text `s` through the automaton and returns the sum of cntword[]
// over every node reached, directly or through fail pointers. Each node is
// counted at most once: visited nodes get cntword[] = -1, so the global table
// is consumed by the call.
int query(string s){
    int node = 0, total = 0;
    for (char ch : s) {
        node = trie[node][ch - 'a'];
        // Walk the fail chain until the root or an already-counted node.
        int j = node;
        while (j != 0 && cntword[j] != -1) {
            total += cntword[j];
            cntword[j] = -1;  // mark as counted
            j = fail[j];
        }
    }
    return total;
}
// Reads n words and one text from standard input, builds the automaton and
// prints the value returned by query() for the text.
int main() {
    int n;
    cin >> n;
    string s;
    for (int remaining = n; remaining > 0; --remaining) {
        cin >> s;
        insertWords(s);
    }
    fail[0] = 0;
    getFail();
    cin >> s;
    cout << query(s) << endl;
    return 0;
}
常见题目和思想
带子串包含约束的lcs问题
思路:
把ac自动机的结点当做dp的状态, 然后做top-down dp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// Reads the next decimal integer from stdin.
// Characters before the first digit are skipped; a '-' among them makes the
// result negative. Returns 0 if end of input is reached before any digit.
inline int read() {
    // getchar() returns int: keeping it as int lets EOF be told apart from
    // every real character, so the skip loop cannot spin forever at EOF.
    int c = getchar();
    int sign = 1;
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    int value = 0;
    while (c >= '0' && c <= '9') {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value * sign;
}
// Capacity of the automaton node tables.
const int N = 1805;
// n, m: the first two integers read by main, used as the lengths of strx and stry; k: number of pattern strings.
// ch: trie / automaton transitions over 52 letter codes (see Type); tot: index of the last allocated trie node.
// fail: fail pointers; ed: per-node bitmask of the patterns recognised at that node; ans: best value found by DP().
int n, m, k, ch[N][52], tot, fail[N], ed[N], ans;
// tsx[i][c] / tsy[i][c]: smallest position greater than i where letter code c occurs in strx / stry (n + 1 / m + 1 if none).
// l: pattern lengths, 1-based; lst: scratch "last seen position" table used while filling tsx / tsy.
int tsx[305][52], tsy[305][52], l[N], lst[52];
// Input buffers; main reads each string starting at index 1.
char strx[305], stry[305], str[305];
// Search state used by DP(): a position in each of the two input strings,
// an automaton node and a bitmask. Ordered lexicographically by
// (px, py, u, mask) so that it can serve as a std::map key.
struct node {
    int px, py, u, mask;
    node(int cpx = 0, int cpy = 0, int cu = 0, int cmask = 0)
        : px(cpx), py(cpy), u(cu), mask(cmask) {}
    bool operator < (const node &c) const {
        return std::tie(px, py, u, mask) < std::tie(c.px, c.py, c.u, c.mask);
    }
};
// stp: best value recorded so far for each state visited by DP(); inq: per-state flag consulted by DP() before pushing a state onto its queue.
map<node, int> stp; map<node, bool> inq;
// Maps a letter to its code: 'a'..'z' -> 0..25, anything else is treated as
// an uppercase letter and mapped to (c - 'A') + 26, i.e. 'A'..'Z' -> 26..51.
int Type(char c) {
    const bool is_lower = (c >= 'a' && c <= 'z');
    return is_lower ? c - 'a' : c - 'A' + 26;
}
#define SS ch[u][i]
void BFS() {
queue<int> que; while (!que.empty()) que.pop();
for (int i = 0; i < 52; i++) if (ch[0][i]) que.push(ch[0][i]);
while (!que.empty()) {
int u = que.front(); que.pop();
for (int i = 0; i < 52; i++)
if (!SS) SS = ch[fail[u]][i];
else que.push(SS), fail[SS] = ch[fail[u]][i], ed[SS] |= ed[fail[SS]];
}
}
void DP() {
queue<node> que; while (!que.empty()) que.pop();
stp[node(0, 0, 0, 0)] = 1; que.push(node(0, 0, 0, 0)), inq[node(0, 0, 0, 0)] = true;
while (!que.empty()) {
node now = que.front(); que.pop(); inq[now] = false;
int ntp = stp[now];
int px = now.px, py = now.py, u = now.u, mask = now.mask;
if (mask == (1 << k) - 1) ans = max(ans, ntp);
for (int i = 0; i < 52; i++) {
int nx = tsx[px][i], ny = tsy[py][i], v = ch[u][i], nsk = mask | ed[v];
if (nx > n || ny > m) continue; node to = node(nx, ny, v, nsk);
if (!stp[to] || stp[to] < ntp + 1) {
stp[to] = ntp + 1;
if (!inq[to]) que.push(to);
}
}
}
}
// Reads n, m, k, the k pattern lengths, the two strings and the k patterns,
// builds the next-occurrence tables and the automaton, runs DP() and prints
// ans - 1.
int main() {
    n = read(); m = read(); k = read();
    for (int i = 1; i <= k; i++) l[i] = read();
    // Both strings are stored 1-based; index 0 stays '\0'.
    scanf("%s", strx + 1), scanf("%s", stry + 1);
    // tsx[i][c]: first position after i holding letter code c, n + 1 if none.
    for (int i = 0; i < 52; i++) lst[i] = n + 1;
    for (int i = n; i >= 0; i--) {
        for (int j = 0; j < 52; j++) tsx[i][j] = lst[j];
        // Position 0 holds no letter; Type('\0') is negative and would index
        // outside lst.
        if (i > 0) lst[Type(strx[i])] = i;
    }
    // Same table for the second string.
    for (int i = 0; i < 52; i++) lst[i] = m + 1;
    for (int i = m; i >= 0; i--) {
        for (int j = 0; j < 52; j++) tsy[i][j] = lst[j];
        if (i > 0) lst[Type(stry[i])] = i;
    }
    // Insert the patterns into the trie; pattern i owns bit i - 1 of ed[].
    for (int i = 1; i <= k; i++) {
        scanf("%s", str + 1); int u = 0;
        for (int j = 1; j <= l[i]; j++) {
            int p = Type(str[j]);
            if (!ch[u][p]) ch[u][p] = ++tot;
            u = ch[u][p];
        }
        ed[u] |= 1 << (i - 1);
    }
    BFS(), DP();
    printf("%d\n", ans - 1);
    return 0;
}
- 用fail指针反向建树, 树状数组对节点统计
- 把往AC自动机里添加变成一步步删除
- 把AC自动机的状态作为DP的一个维度
#include <iostream>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <string>
#include <vector>
#include <stack>
#include <queue>
#include <map>
using namespace std;
namespace zzy {
namespace algo{
/// Aho-Corasick automaton over the lowercase letters 'a'..'z' that also
/// supports removing patterns.
///
/// Typical use: Insert() every pattern, call BuildAc(), then query with
/// GetOccurenceOneTime() / GetOccurenceMoreTime() and Remove() patterns as
/// needed. After further Insert() calls, BuildAc() must be called again
/// before querying.
///
/// Precondition for every method taking a string: it contains only
/// 'a'..'z'; other characters index outside the transition table.
class AcTree {
    struct AcNode {
        int pass_cnt;  // number of stored patterns whose path runs through this node
        int end;       // number of stored patterns that end at this node
        int fail;      // fail pointer (node index), valid after BuildAc()
        int tmp_end;   // scratch copy of `end` used by GetOccurenceOneTime(), -1 = visited
        int depth;     // distance from the root; tells trie edges from fallback transitions
        int to[26];    // transition per letter; 0 = none (the root is never a child)
    };

public:
    /// Creates an automaton that contains only the root node.
    AcTree() : ac_tree_(1, AcNode{}) {
    }

    /// Creates an automaton that contains only the root node and reserves
    /// room for `n` nodes. Values of `n` below 2 reserve nothing.
    AcTree(int n) : ac_tree_(1, AcNode{}) {
        if (n > 1)
            ac_tree_.reserve(static_cast<std::size_t>(n));
    }

    virtual ~AcTree() { }

    /// Adds one copy of `str` to the pattern set.
    void Insert(const std::string& str) {
        int cur = 0;
        for (char c : str) {
            const int slot = c - 'a';
            if (!IsTrieEdge(cur, ac_tree_[cur].to[slot])) {
                // Either no transition yet or only a fallback transition
                // written by BuildAc(): create the real child.
                AcNode fresh{};
                fresh.depth = ac_tree_[cur].depth + 1;
                const int id = static_cast<int>(ac_tree_.size());
                ac_tree_.push_back(fresh);
                ac_tree_[cur].to[slot] = id;
            }
            cur = ac_tree_[cur].to[slot];
            ac_tree_[cur].pass_cnt++;
        }
        ac_tree_[cur].end++;
    }

    /// Removes one copy of `str` from the pattern set.
    /// Does nothing if `str` is not currently stored.
    void Remove(const std::string& str) {
        // First pass: find the node of `str`, following trie edges only, so
        // that an absent string cannot touch unrelated counters.
        int cur = 0;
        for (char c : str) {
            const int nxt = ac_tree_[cur].to[c - 'a'];
            if (!IsTrieEdge(cur, nxt))
                return;
            cur = nxt;
        }
        if (ac_tree_[cur].end <= 0)
            return;
        // Dropping `end` is what stops the pattern from being counted; the
        // path counters alone cannot express that while another pattern
        // still shares the path.
        ac_tree_[cur].end--;
        cur = 0;
        for (char c : str) {
            cur = ac_tree_[cur].to[c - 'a'];
            ac_tree_[cur].pass_cnt--;
        }
    }

    /// Computes all fail pointers breadth-first and replaces every missing
    /// transition by the transition of the fail node. May be called again
    /// after more patterns have been inserted.
    void BuildAc() {
        std::queue<int> q;
        // The root never receives fallback transitions, so every non-zero
        // entry is a real child; its fail pointer is the root.
        for (int i = 0; i < 26; i++) {
            const int child = ac_tree_[0].to[i];
            if (child) {
                ac_tree_[child].fail = 0;
                q.push(child);
            }
        }
        while (!q.empty()) {
            const int cur = q.front();
            q.pop();
            const int fallback = ac_tree_[cur].fail;
            for (int i = 0; i < 26; i++) {
                const int via_fail = ac_tree_[fallback].to[i];
                const int next = ac_tree_[cur].to[i];
                if (IsTrieEdge(cur, next)) {
                    ac_tree_[next].fail = via_fail;
                    q.push(next);
                } else {
                    ac_tree_[cur].to[i] = via_fail;
                }
            }
        }
    }

    /// Returns how many stored patterns (counted with multiplicity) occur at
    /// least once in `str`.
    int GetOccurenceOneTime(const std::string& str) {
        for (std::size_t i = 0; i < ac_tree_.size(); i++)
            ac_tree_[i].tmp_end = ac_tree_[i].end;
        int cur = 0, res = 0;
        for (char c : str) {
            cur = ac_tree_[cur].to[c - 'a'];
            // Stop at the root or at a node whose chain was already counted.
            for (int i = cur; i && ac_tree_[i].tmp_end != -1; i = ac_tree_[i].fail) {
                res += ac_tree_[i].tmp_end;
                ac_tree_[i].tmp_end = -1;
            }
        }
        return res;
    }

    /// Returns the total number of occurrences in `str` of all stored
    /// patterns (each occurrence counted once per stored copy).
    int GetOccurenceMoreTime(const std::string& str) const {
        int cur = 0, res = 0;
        for (char c : str) {
            cur = ac_tree_[cur].to[c - 'a'];
            for (int i = cur; i; i = ac_tree_[i].fail)
                res += ac_tree_[i].end;
        }
        return res;
    }

private:
    /// True if `child` is a real trie child of `parent`. A fallback
    /// transition always leads to a node that is not deeper than `parent`,
    /// and "no transition" is the root (depth 0).
    bool IsTrieEdge(int parent, int child) const {
        return ac_tree_[child].depth == ac_tree_[parent].depth + 1;
    }

    std::vector<AcNode> ac_tree_;
};
}
}
文本生成器
思路:
ac自动机+dp
#include <iostream>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <string>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <string.h>
#include <cstdlib>
using namespace std;
// Number of rows of the dp table (one row per text length).
#define MAX_M 200
// Capacity of the automaton node tables and of the dp state dimension.
#define MAX_STATE 10005
typedef long long int ll;
// Modulus applied to every count.
ll kMod = 10007;
// n: number of words read by read(); m: text length used by solve().
ll n, m;
// dp[i][j]: number (mod kMod) of ways to write i letters ending in automaton state j while only entering states not flagged in AcTree::end.
ll dp[MAX_M][MAX_STATE];
class AcTree {
public:
ll trie[MAX_STATE][26];
ll fail[MAX_STATE];
bool end[MAX_STATE];
ll cnt;
void Insert(const string& str) {
ll cur = 0;
for (auto c : str) {
if (trie[cur][c - 'A'])
cur = trie[cur][c - 'A'];
else {
trie[cur][c - 'A'] = ++cnt;
cur = cnt;
}
}
end[cur] = true;
}
void Build() {
queue<ll> q;
for (int i = 0; i < 26; i++)
if (trie[0][i])
q.push(trie[0][i]);
while (!q.empty()) {
ll t = q.front();
q.pop();
end[t] |= end[fail[t]];
for (int i = 0;i < 26; i++) {
if (trie[t][i]) {
fail[trie[t][i]] = trie[fail[t]][i];
q.push(trie[t][i]);
}
else
trie[t][i] = trie[fail[t]][i];
}
}
}
} ac_tree;
// Returns x raised to the power y, reduced modulo kMod, by binary
// exponentiation.
ll PowMod(ll x, ll y) {
    ll result = 1;
    for (ll base = x; y; y >>= 1) {
        if (y & 1)
            result = (result * base) % kMod;
        base = (base * base) % kMod;
    }
    return result;
}
void read() {
cin >> n >> m;
for (int i = 0; i < n; i++) {
string tmp;
cin >> tmp;
ac_tree.Insert(tmp);
}
ac_tree.Build();
}
void solve() {
dp[0][0] = 1;
ll tot = ac_tree.cnt;
for (int i = 1; i <= m; i++) {
for (int j = 0; j <= tot; j++) {
for (int k = 0; k < 26; k++) {
if (!ac_tree.end[ac_tree.trie[j][k]]) {
ll nxt = ac_tree.trie[j][k];
dp[i][nxt] = (dp[i][nxt] + dp[i - 1][j]) % kMod;
}
}
}
}
ll sum = 0;
for (int i = 0; i <= tot; i++)
sum = (sum + dp[m][i]) % kMod;
cout << (PowMod(26, m) + kMod - sum) % kMod << endl;
}
// Entry point: switches iostream to unsynchronised, untied mode, then reads
// the input and prints the answer.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    read();
    solve();
    return 0;
}
阿狸的打字机
思路:
fail反向建树, dfs获得顺序, 然后树状数组保存, 看串和串之间的包含关系
#include <cstdio>
#include <cstdlib>
#include <cstring>
#define N 100000
// One entry of a linked adjacency list stored in an array.
struct edge
{
    int next;  // index of the next entry in the same list; 0 ends the list
    int node;  // value attached to this entry
};
struct map
{
int head[N + 1], tot;
edge e[N + 1];
map() { tot = 0; }
inline void addedge(int x, int y)
{
e[++tot].next = head[x];
e[tot].node = y;
head[x] = tot;
}
}fail, queries;
//Forward stars, one for fail pointer tree and one for queries
// Index of the trie root. Nodes are numbered from 1, so a 0 in a `son` slot means "no child".
const int root = 1;
// Trie node; t[] is the node pool indexed by node number.
struct node
{
    int f;        // parent node, followed when a 'B' is processed
    int son[26];  // child per letter, 0 = absent
    int fail;     // fail pointer
}t[N + 1];
//Trie tree
//Entry string
char s[N + 1];
//len: length of s; tot: index of the last allocated trie node; strs: number of 'P' characters processed so far; m: number of queries
int len, tot = root, strs = 0, m;
//Corresponding trie node for a string
int strNode[N + 1];
//Array used as the BFS queue while the fail pointers are built
int queue[N + 1];
//DFS sequence: head[x] is the DFS number of x in the fail pointer tree, tail[x] the largest DFS number in its subtree, count the running counter
int head[N + 1], tail[N + 1], count = 0;
//Binary indexed tree, storing how many nodes of string y are in a certain node and its subtrees
int arr[N + 1];
//Stores the answer to each of the queries
int ans[N + 1];
// Binary indexed tree prefix sum: returns arr-sum over positions 1..x.
inline int query(int x)
{
    int sum = 0;
    while (x)
    {
        sum += arr[x];
        x -= x & (-x);  // drop the lowest set bit
    }
    return sum;
}
// Binary indexed tree point update: adds d at position x (positions run up
// to `count`).
inline void modify(int x, int d)
{
    while (x <= count)
    {
        arr[x] += d;
        x += x & (-x);  // move to the next covering position
    }
}
// Numbers the fail pointer tree depth-first: head[x] is assigned on entry,
// tail[x] is the last number handed out inside the subtree of x.
void dfs(int x)
{
    head[x] = ++count;
    int i = fail.head[x];
    while (i)
    {
        dfs(fail.e[i].node);
        i = fail.e[i].next;
    }
    tail[x] = count;
}
// Reads the key sequence and the queries, builds the trie, the fail pointers
// and the fail pointer tree, then replays the key sequence while answering
// every query (x, y) at the moment string y is printed.
int main()
{
    scanf("%s", s);
    len = strlen(s);
    int now = root;

    // Read the queries, grouped by y so they can be answered when string y
    // is reached. The i-th query becomes entry i of `queries`.
    scanf("%d", &m);
    for (int i = 0; i < m; ++i)
    {
        int x, y;
        scanf("%d%d", &x, &y);
        queries.addedge(y, x);
    }

    // Build the trie: 'P' records the current node as the next string,
    // 'B' steps back to the parent, a letter descends (creating the child
    // if needed).
    for (int i = 0; i < len; ++i)
    {
        const char key = s[i];
        if (key == 'P')
        {
            strNode[++strs] = now;
            continue;
        }
        if (key == 'B')
        {
            now = t[now].f;
            continue;
        }
        int &child = t[now].son[key - 'a'];
        if (!child)
        {
            child = ++tot;
            t[child].f = now;
        }
        now = child;
    }

    // Fail pointers, breadth-first; every pointer also becomes an edge of
    // the fail pointer tree. Children of the root point back to the root.
    int l = 0, r = -1;
    for (int i = 0; i < 26; ++i)
    {
        const int child = t[root].son[i];
        if (!child) continue;
        queue[++r] = child;
        fail.addedge(root, child);
        t[child].fail = root;
    }
    for (; l <= r; ++l)
    {
        const int u = queue[l];
        for (int i = 0; i < 26; ++i)
        {
            const int child = t[u].son[i];
            if (!child) continue;
            queue[++r] = child;
            // Climb the fail chain of the parent until a node with a child
            // for letter i is found or the root is reached.
            now = t[u].fail;
            while (now != root && !t[now].son[i])
                now = t[now].fail;
            const int target = t[now].son[i] ? t[now].son[i] : root;
            t[child].fail = target;
            fail.addedge(target, child);
        }
    }

    // DFS numbering of the fail pointer tree.
    dfs(root);

    // Replay the key sequence. The nodes on the current trie path carry +1
    // in the binary indexed tree, so a subtree sum in the fail pointer tree
    // counts how many of them fall inside that subtree.
    now = root, strs = 0;
    for (int i = 0; i < len; ++i)
    {
        const char key = s[i];
        if (key == 'B')
        {
            modify(head[now], -1);
            now = t[now].f;
        }
        else if (key == 'P')
        {
            ++strs;
            for (int x = queries.head[strs]; x; x = queries.e[x].next)
            {
                const int target = strNode[queries.e[x].node];
                ans[x] = query(tail[target]) - query(head[target] - 1);
            }
        }
        else
        {
            now = t[now].son[key - 'a'];
            modify(head[now], 1);
        }
    }
    for (int i = 1; i <= m; ++i)
        printf("%d\n", ans[i]);
    return 0;
}