AC自动机

AC自动机

前言

在字符串匹配的相关问题中，一个很常见的思想是避免完全回溯。比如在 KMP 算法中，如果在某个位置发现不匹配，不会从头重新开始匹配，而是直接跳到下一个可能匹配的位置，从而避免 $O(m \cdot n)$ 的复杂度。KMP 利用 next 数组把一对字符串的匹配复杂度降到了 $O(m + n)$；AC 自动机解决的则是多字符串匹配问题，如下：

多字符串匹配问题:

给定字符串 $S$ 和字符串集合 $T = \{T_1, T_2, T_3, T_4, T_5, \dots, T_n\}$，

查找集合T中的字符串在S中的出现次数

背景知识

KMP

kmp算法的核心在于next数组

Next数组:

对于str

next[i] 表示 str[0..i]（含两端）中，既是真前缀又是真后缀的最长子串的长度

令l=next[i]

也就是 str[0..l-1] = str[i-l+1..i]（两端均包含）

有了 next 数组后，当 str1 和 str2 在某个位置失配时，不需要把 str2 退回到第一个字符重新比较，而是直接跳到 next 数组给出的位置继续匹配。

代码如下

// KMP (Knuth-Morris-Pratt) matcher for a single pattern `str`.
// next[i] is the length of the longest proper prefix of str[0..i] that is
// also a suffix of str[0..i].
class Kmp {
public:
  // Builds the next array for `str` in O(|str|).
  Kmp(const string& str) : next(str.length(), 0), str(str) {
    // cur = length of the border currently being extended.
    int cur = 0;
    for (size_t i = 1; i < str.length(); ) {
      if (str[i] == str[cur]) {
        next[i++] = ++cur;
      } else if (cur != 0) {
        // Fall back to the next shorter border and retry the same i.
        cur = next[cur - 1];
      } else {
        i++;
      }
    }
  }

  // Returns every starting index at which `str` occurs in `other`
  // (overlapping occurrences included), in increasing order.
  // An empty pattern yields no positions.
  vector<int> GetOccPos(const string& other) const {
    vector<int> res;
    if (str.empty())
      return res;  // otherwise a '\0' in `other` would match str[0] and run past the pattern
    size_t cur = 0;  // number of pattern characters currently matched
    for (size_t i = 0; i < other.length(); ) {
      if (other[i] == str[cur]) {
        i++, cur++;
        if (cur == str.length()) {
          res.push_back(static_cast<int>(i - str.length()));
          cur = next[cur - 1];  // keep the border so overlapping matches are found
        }
      } else if (cur != 0) {
        cur = next[cur - 1];
      } else {
        i++;
      }
    }
    return res;
  }

  // Returns a copy of the next array.
  vector<int> GetNextArray() const { return next; }

private:
  vector<int> next;
  string str;
};

Trie

Trie树是一种字符串查找的树, 每个节点代表一个字符, 子节点是接下来的字符

以这张图为例

插入字符串code, cook, five, file, fat 得到以下的一棵树

[图：依次插入 code、cook、five、file、fat 后得到的 Trie 树（原图 image-20201223124041245 缺失）]

trie的代码参考如下

// Trie over lowercase words ('a'..'z'). Nodes are stored in a vector and
// refer to their children by index; index 0 is the root, and a child index
// of 0 means "no child" (the root is never anyone's child).
class Trie {
public:
    /** Initialize your data structure here. */
    Trie() {
        tree_.push_back(TrieNode());
    }

    /** Inserts a word into the trie. `word` must consist of 'a'..'z' only. */
    void insert(const string& word) {
        int cur = 0;
        for (auto c : word) {
            int idx = c - 'a';
            if (!tree_[cur].next[idx]) {
                tree_.push_back(TrieNode());
                tree_[cur].next[idx] = static_cast<int>(tree_.size()) - 1;
            }
            cur = tree_[cur].next[idx];
        }
        tree_[cur].is_end = true;
    }

    /** Returns if the word is in the trie. */
    bool search(const string& word) const {
        int node = walk(word);
        return node >= 0 && tree_[node].is_end;
    }

    /** Returns if there is any word in the trie that starts with the given prefix. */
    bool startsWith(const string& prefix) const {
        return walk(prefix) >= 0;
    }

private:
    struct TrieNode {
        int next[26];  // child index per letter, 0 = none
        bool is_end;   // a word ends at this node
    };

    // Follows `s` from the root; returns the node reached, or -1 if some
    // character has no edge (characters outside 'a'..'z' never have one).
    int walk(const string& s) const {
        int cur = 0;
        for (auto c : s) {
            if (c < 'a' || c > 'z' || !tree_[cur].next[c - 'a'])
                return -1;
            cur = tree_[cur].next[c - 'a'];
        }
        return cur;
    }

    vector<TrieNode> tree_;
};

/**
 * Your Trie object will be instantiated and called as such:
 * Trie* obj = new Trie();
 * obj->insert(word);
 * bool param_2 = obj->search(word);
 * bool param_3 = obj->startsWith(prefix);
 */

AC自动机算法

AC自动机是对于trie的改进, 加入了fail指针, 解决多字符串匹配问题

// Node of an AC automaton over 'a'..'z': a trie node plus a fail link.
struct AcNode {
  int next[26] = {};    // child index for each letter (0 = no child)
  int fail = 0;         // index of the node the fail link points to
  bool is_end = false;  // whether some pattern ends at this node
};

其中 fail 指针指向的是：当前结点所代表字符串的最长真后缀中、同样出现在 Trie 里（即是某个模式串前缀）的那一个所对应的结点；如果不存在这样的后缀，则指向根结点。

画了个示意图, 其中红线就是fail指针

对于abc来说, 出现在ac自动机的最长后缀是bc因此abc指向bc, 对于ab来说 最长后缀是b, 因此ab指向b,这张图比较形象地展示了ac自动机的工作方式

[图：AC 自动机示意图，红线为 fail 指针（原图 image-20201224084437252 缺失）]

下面看代码

#include <queue>
#include <cstdlib>
#include <cmath>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
// Capacity: maximum number of trie nodes (2 * 10^6 + 9).
const int maxn = 2000009;

int trie[maxn][26];  // trie[u][c]: child of node u on letter 'a' + c (goto table after getFail)
int cntword[maxn];   // number of inserted patterns ending at each node
int fail[maxn];      // fail link of each node
int cnt = 0;         // index of the most recently created node (root is 0)

// Inserts pattern `s` (letters 'a'..'z') into the trie, creating nodes as
// needed, and counts one more pattern ending at its last node.
void insertWords(const string& s){
    int cur = 0;  // node reached so far, starting at the root
    for (char ch : s) {
        int next = ch - 'a';
        if (!trie[cur][next])
            trie[cur][next] = ++cnt;  // allocate a new node
        cur = trie[cur][next];
    }
    cntword[cur]++;  // one more pattern ends here
}
// Builds fail links by BFS over the trie and turns `trie` into a complete
// goto table: afterwards every trie[u][i] is defined, with missing children
// redirected to trie[fail[u]][i].
//   fail[now]    -> node that the fail link of `now` points to
//   trie[now][i] -> node reached from `now` on letter i + 'a'
void getFail(){
    queue<int> q;
    // Depth-1 nodes fail to the root (node 0).
    for (int i = 0; i < 26; i++) {
        if (trie[0][i]) {
            fail[trie[0][i]] = 0;
            q.push(trie[0][i]);
        }
    }

    while (!q.empty()) {
        int now = q.front();
        q.pop();

        for (int i = 0; i < 26; i++) {
            int child = trie[now][i];
            if (child) {
                // The child's fail link is the i-transition of the parent's
                // fail target (already complete, because BFS goes by depth).
                fail[child] = trie[fail[now]][i];
                q.push(child);
            } else {
                // Missing edge: reuse the transition of the fail target.
                trie[now][i] = trie[fail[now]][i];
            }
        }
    }
}


// Returns how many inserted patterns occur in text `s` at least once
// (a pattern inserted several times is counted that many times).
// Destructive: each visited node's cntword is set to -1 so it is counted
// only once, so a second call will not count those patterns again.
int query(const string& s){
    int now = 0, ans = 0;
    for (char ch : s) {
        now = trie[now][ch - 'a'];
        // Walk the fail chain until the root or an already-counted node
        // (whose fail ancestors were then counted too).
        for (int j = now; j && cntword[j] != -1; j = fail[j]) {
            ans += cntword[j];
            cntword[j] = -1;
        }
    }
    return ans;
}

int main() {
    int n;
    string s;
    cin >> n;
    // Read the n patterns and insert each one into the trie.
    for (int left = n; left > 0; --left) {
        cin >> s;
        insertWords(s);
    }
    fail[0] = 0;  // the root's fail link points to the root itself
    getFail();
    // Read the text and print how many patterns occur in it.
    cin >> s;
    cout << query(s) << endl;
    return 0;
}


常见题目和思想

带子串包含约束的lcs问题

[图：题目描述（原图 image-20201224085514314 缺失）]

思路:

把 AC 自动机的结点作为 DP 状态的一维：状态为 (在 X 中的位置, 在 Y 中的位置, 自动机结点, 已包含子串的掩码)，然后在这些状态上做 DP。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Reads one decimal integer from stdin, skipping any non-digit characters
// before it; a '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit.
inline int read() {
	int c = getchar();  // int, not char, so EOF is distinguishable
	int x = 0, f = 1;
	while (c != EOF && (c < '0' || c > '9')) {
		if (c == '-') f = -1;
		c = getchar();
	}
	while (c >= '0' && c <= '9') {
		x = x * 10 + (c - '0');
		c = getchar();
	}
	return x * f;
}

const int N = 1805;  // capacity for automaton nodes

// n, m: lengths of X and Y; k: number of required substrings.
int n, m, k;
// ch[u][c]: automaton transition from node u on letter class c (0..51);
// tot: number of allocated nodes; fail: fail links;
// ed[u]: bitmask of required substrings recognised at node u; ans: best DP value.
int ch[N][52], tot, fail[N], ed[N], ans;
// tsx[i][c] / tsy[i][c]: next position after i holding letter class c in X / Y;
// l[i]: length of the i-th required substring; lst: scratch table.
int tsx[305][52], tsy[305][52], l[N], lst[52];
// 1-indexed input strings X, Y and the substring currently being read.
char strx[305], stry[305], str[305];

// DP state: position in X, position in Y, automaton node, and the bitmask
// of required substrings already contained.
struct node {
	int px, py, u, mask;
	node(int cpx = 0, int cpy = 0, int cu = 0, int cmask = 0)
		: px(cpx), py(cpy), u(cu), mask(cmask) {}
	// Lexicographic order on (px, py, u, mask) so node can key a std::map.
	bool operator < (const node &c) const {
		return std::tie(px, py, u, mask) < std::tie(c.px, c.py, c.u, c.mask);
	}
};
// stp: best value found for each state; inq: whether a state is queued.
map<node, int> stp;
map<node, bool> inq;

// Maps a letter to its class index: 'a'..'z' -> 0..25; anything else maps to
// c - 'A' + 26 (so 'A'..'Z' -> 26..51).
int Type(char c) {
	const bool lower = c >= 'a' && c <= 'z';
	return lower ? c - 'a' : 26 + (c - 'A');
}

// Builds fail links for the automaton in `ch` by BFS from the root (node 0)
// and completes missing transitions into a goto table. Each node's `ed` mask
// also absorbs the mask of its fail target, so ed[u] covers the required
// substrings recognised through u's fail chain.
void BFS() {
	queue<int> que;
	for (int i = 0; i < 52; i++)
		if (ch[0][i]) que.push(ch[0][i]);  // depth-1 nodes keep fail = 0 (the root)
	while (!que.empty()) {
		int u = que.front(); que.pop();
		for (int i = 0; i < 52; i++) {
			int &child = ch[u][i];
			if (!child) {
				child = ch[fail[u]][i];
			} else {
				que.push(child);
				fail[child] = ch[fail[u]][i];
				ed[child] |= ed[fail[child]];
			}
		}
	}
}

void DP() {
	queue<node> que; while (!que.empty()) que.pop();
	stp[node(0, 0, 0, 0)] = 1; que.push(node(0, 0, 0, 0)), inq[node(0, 0, 0, 0)] = true;
	while (!que.empty()) {
		node now = que.front(); que.pop(); inq[now] = false;
		int ntp = stp[now];
		int px = now.px, py = now.py, u = now.u, mask = now.mask;
		if (mask == (1 << k) - 1) ans = max(ans, ntp);
		for (int i = 0; i < 52; i++) {
			int nx = tsx[px][i], ny = tsy[py][i], v = ch[u][i], nsk = mask | ed[v];
			if (nx > n || ny > m) continue; node to = node(nx, ny, v, nsk);
			if (!stp[to] || stp[to] < ntp + 1) {
				stp[to] = ntp + 1;
				if (!inq[to]) que.push(to);
			}
		}
	}
}

int main() {
	n = read(); m = read(); k = read();
	for (int i = 1; i <= k; i++) l[i] = read();
	scanf("%s", strx + 1), scanf("%s", stry + 1);
	// tsx[i][c] = smallest position p > i with Type(strx[p]) == c (n + 1 if none).
	for (int i = 0; i < 52; i++) lst[i] = n + 1;
	for (int i = n; i >= 0; i--) {
		for (int j = 0; j < 52; j++) tsx[i][j] = lst[j];
		// Position 0 is a sentinel (strx[0] == '\0'); Type('\0') is negative,
		// so recording it would write before lst[0].
		if (i > 0) lst[Type(strx[i])] = i;
	}
	// Same table for Y.
	for (int i = 0; i < 52; i++) lst[i] = m + 1;
	for (int i = m; i >= 0; i--) {
		for (int j = 0; j < 52; j++) tsy[i][j] = lst[j];
		if (i > 0) lst[Type(stry[i])] = i;
	}
	// Insert the k required substrings; substring i sets bit i - 1 at its end node.
	for (int i = 1; i <= k; i++) {
		scanf("%s", str + 1); int u = 0;
		for (int j = 1; j <= l[i]; j++) {
			int p = Type(str[j]);
			if (!ch[u][p]) ch[u][p] = ++tot;
			u = ch[u][p];
		}
		ed[u] |= (1 << (i - 1));
	}
	BFS(), DP();
	// The start state counts as 1, so subtract it; prints -1 when no state
	// containing every required substring was reached (ans stays 0).
	printf("%d\n", ans - 1);
	return 0;
}
  1. 用fail指针反向建树, 树状数组对节点统计
  2. 把往AC自动机里添加变成一步步删除
  3. 把AC自动机的状态作为DP的一个维度
#include <iostream>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <string>
#include <vector>
#include <stack>
#include <queue>
#include <map>
using namespace std;

namespace zzy {
namespace algo{

// Aho-Corasick automaton over 'a'..'z' that also supports removing
// previously inserted patterns.
//
// Usage: Insert() every pattern, call BuildAc() once, then query with
// GetOccurence*(). Remove() works before or after BuildAc(). Insert() after
// BuildAc() is not supported: BuildAc() fills missing children with goto
// transitions, which Insert() would then follow as if they were trie edges.
class AcTree {
  struct AcNode {
    int pass_cnt = 0;  // live patterns whose trie path passes through this node
    int end = 0;       // live patterns ending exactly at this node
    int fail = 0;      // fail link (0 = root)
    int tmp_end = 0;   // scratch copy of `end` used by GetOccurenceOneTime()
    int depth = 0;     // length of the string spelled from the root to this node
    int to[26] = {};   // trie children; after BuildAc() the complete goto table
  };

public:
  AcTree() : ac_tree_(1) {
  }

  // Starts with only the root and reserves room for `n` nodes (n <= 0 is fine).
  AcTree(int n) : ac_tree_(1) {
    if (n > 0)
      ac_tree_.reserve(n);
  }

  virtual ~AcTree() { }

  // Adds one copy of pattern `str` (characters must be 'a'..'z').
  void Insert(const string& str) {
    int cur = 0;
    for (auto c : str) {
      int idx = c - 'a';
      if (!ac_tree_[cur].to[idx]) {
        ac_tree_[cur].to[idx] = static_cast<int>(ac_tree_.size());
        AcNode created;
        created.depth = ac_tree_[cur].depth + 1;
        ac_tree_.push_back(created);
      }
      cur = ac_tree_[cur].to[idx];
      ac_tree_[cur].pass_cnt++;
    }
    ac_tree_[cur].end++;
  }

  // Removes one copy of pattern `str`. Does nothing if `str` is not currently
  // present (never inserted, or already removed as often as it was inserted).
  void Remove(const string& str) {
    int node = Locate(str);
    if (node < 0 || ac_tree_[node].end <= 0)
      return;
    ac_tree_[node].end--;
    int cur = 0;
    for (auto c : str) {
      cur = ac_tree_[cur].to[c - 'a'];  // real trie edges, verified by Locate()
      ac_tree_[cur].pass_cnt--;
    }
  }

  // Computes fail links by BFS and fills missing children with goto transitions.
  void BuildAc() {
    queue<int> q;
    for (int i = 0; i < 26; i++)
      if (ac_tree_[0].to[i])
        q.push(ac_tree_[0].to[i]);  // depth-1 nodes fail to the root
    while (!q.empty()) {
      int cur = q.front();
      q.pop();
      for (int i = 0; i < 26; i++) {
        if (ac_tree_[cur].to[i]) {
          int next = ac_tree_[cur].to[i];
          ac_tree_[next].fail = ac_tree_[ac_tree_[cur].fail].to[i];
          q.push(next);
        } else {
          ac_tree_[cur].to[i] = ac_tree_[ac_tree_[cur].fail].to[i];
        }
      }
    }
  }

  // Sums `end` over the distinct pattern nodes that occur in `str`; each node
  // contributes at most once. Requires BuildAc().
  int GetOccurenceOneTime(const string& str) {
    for (auto& node : ac_tree_)
      node.tmp_end = node.end;
    int cur = 0, res = 0;
    for (auto c : str) {
      cur = ac_tree_[cur].to[c - 'a'];
      // Stop at a node already visited: its fail ancestors were visited too.
      for (int i = cur; i && ac_tree_[i].tmp_end != -1; i = ac_tree_[i].fail) {
        if (ac_tree_[i].pass_cnt)
          res += ac_tree_[i].tmp_end;
        ac_tree_[i].tmp_end = -1;
      }
    }
    return res;
  }

  // Counts every occurrence (with overlaps) of every live pattern in `str`.
  // Requires BuildAc().
  int GetOccurenceMoreTime(const string& str) {
    int cur = 0, res = 0;
    for (auto c : str) {
      cur = ac_tree_[cur].to[c - 'a'];
      for (int i = cur; i; i = ac_tree_[i].fail) {
        if (ac_tree_[i].pass_cnt)
          res += ac_tree_[i].end;
      }
    }
    return res;
  }

private:
  // Returns the trie node spelling exactly `str`, or -1 if `str` is not a
  // trie path. Following `to` never increases depth by more than one per
  // character, and only real trie edges do so, so the final node has depth
  // |str| exactly when every step was a real edge. This holds before and
  // after BuildAc().
  int Locate(const string& str) const {
    int cur = 0;
    for (auto c : str) {
      int idx = c - 'a';
      if (idx < 0 || idx >= 26)
        return -1;
      cur = ac_tree_[cur].to[idx];
    }
    return ac_tree_[cur].depth == static_cast<int>(str.size()) ? cur : -1;
  }

  vector<AcNode> ac_tree_;

};
}
}

文本生成器

[图：题目描述（原图 image-20201224085708065 缺失）]

思路:

ac自动机+dp

#include <iostream>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <string>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <string.h>
#include <cstdlib>
using namespace std;

#define MAX_M 200        // rows of dp: text lengths 0 .. MAX_M - 1
#define MAX_STATE 10005  // capacity for automaton nodes
using ll = long long;

ll kMod = 10007;  // the answer is printed modulo this value

ll n;  // number of words
ll m;  // length of the generated text
// dp[i][u]: number of length-i texts that contain no word and end in automaton state u.
ll dp[MAX_M][MAX_STATE];

class AcTree {
public:
  ll trie[MAX_STATE][26];
  ll fail[MAX_STATE];
  bool end[MAX_STATE];
  ll cnt;

  void Insert(const string& str) {
    ll cur = 0;
    for (auto c : str) {
      if (trie[cur][c - 'A'])
        cur = trie[cur][c - 'A'];
      else {
        trie[cur][c - 'A'] = ++cnt;
        cur = cnt;
      }
    }
    end[cur] = true;
  }
  
  void Build() {
    queue<ll> q;
    for (int i = 0; i < 26; i++)
      if (trie[0][i])
        q.push(trie[0][i]);
    while (!q.empty()) {
      ll t = q.front();
      q.pop();
      end[t] |= end[fail[t]];
      for (int i = 0;i < 26; i++) {
        if (trie[t][i]) {
          fail[trie[t][i]] = trie[fail[t]][i];
          q.push(trie[t][i]);
        }
        else
          trie[t][i] = trie[fail[t]][i];
      }
    }
  }
  
} ac_tree;

// Computes x^y mod kMod by binary exponentiation; y <= 0 yields 1 % kMod.
// The base is reduced modulo kMod first, so cur * cur cannot overflow for
// large or negative x, and a negative y no longer shifts forever.
ll PowMod(ll x, ll y) {
  ll cur = ((x % kMod) + kMod) % kMod;
  ll res = 1 % kMod;
  while (y > 0) {
    if (y & 1)
      res = (res * cur) % kMod;
    cur = (cur * cur) % kMod;
    y >>= 1;
  }
  return res;
}

void read() {
  cin >> n >> m;
  for (int i = 0; i < n; i++) {
    string tmp;
    cin >> tmp;
    ac_tree.Insert(tmp);
  }
  ac_tree.Build();
}

void solve() {
  dp[0][0] = 1;
  ll tot = ac_tree.cnt;
  for (int i = 1; i <= m; i++) {
    for (int j = 0; j <= tot; j++) {
      for (int k = 0; k < 26; k++) {
        if (!ac_tree.end[ac_tree.trie[j][k]]) {
          ll nxt = ac_tree.trie[j][k];
          dp[i][nxt] = (dp[i][nxt] + dp[i - 1][j]) % kMod;
        }
      }
    }
  }
  ll sum = 0;
  for (int i = 0; i <= tot; i++)
    sum = (sum + dp[m][i]) % kMod;
  cout << (PowMod(26, m) + kMod - sum) % kMod << endl;
}

int main() {
  // Fast iostreams: untie cin from cout and stop syncing with C stdio.
  cin.tie(nullptr);
  ios::sync_with_stdio(false);
  read();
  solve();
  return 0;
}

阿狸的打字机

[图：题目描述（原图 image-20201224085933639 缺失）]

思路:

fail反向建树, dfs获得顺序, 然后树状数组保存, 看串和串之间的包含关系

#include <cstdio>
#include <cstdlib>
#include <cstring>
#define N 100000

// Edge of a forward-star adjacency list.
struct edge
{
    int next;  // previous edge with the same source (0 ends the list)
    int node;  // target vertex
};

// Forward-star graph: head[x] is the most recently added edge leaving x;
// edges are numbered from 1, so 0 terminates a list.
struct map
{
    int head[N + 1], tot;
    edge e[N + 1];
    map() : tot(0) {}
    inline void addedge(int x, int y)
    {
        ++tot;
        e[tot].node = y;
        e[tot].next = head[x];
        head[x] = tot;
    }
} fail, queries;
// Two forward-star graphs: the fail-pointer tree, and the queries grouped by y.
const int root = 1;  // index of the trie root

// Trie node: parent `f`, children `son`, and fail pointer `fail`.
struct node
{
    int f, son[26], fail;
} t[N + 1];

char s[N + 1];  // the typed operation string

int len;           // length of s
int tot = root;    // index of the last allocated trie node
int strs = 0;      // number of printed strings seen so far
int m;             // number of queries
int strNode[N + 1];  // strNode[i]: trie node where the i-th printed string ends
int queue[N + 1];    // BFS queue used while building fail pointers
// DFS order of the fail-pointer tree: the subtree of x is [head[x], tail[x]].
int head[N + 1], tail[N + 1], count = 0;
// Binary indexed tree over DFS positions, counting nodes of the current string y.
int arr[N + 1];
int ans[N + 1];  // ans[i]: answer to the i-th query
 
// Prefix sum of the binary indexed tree: arr-positions 1..x.
inline int query(int x)
{
    int sum = 0;
    while (x != 0)
    {
        sum += arr[x];
        x -= x & (-x);  // step to the parent range by dropping the lowest set bit
    }
    return sum;
}
 
// Adds d at position x of the binary indexed tree (positions 1..count).
inline void modify(int x, int d)
{
    while (x <= count)
    {
        arr[x] += d;
        x += x & (-x);  // move to the next range that covers x
    }
}
 
void dfs(int x)
{
    head[x] = ++count;
    for (int i = fail.head[x]; i; i = fail.e[i].next)
        dfs(fail.e[i].node);
    tail[x] = count;
}
 
int main()
{
    scanf("%s", s);
    len = strlen(s);
    int now = root;
    // Read the queries; query number i (1-based) becomes edge i of `queries`,
    // stored in the adjacency list of its y.
    scanf("%d", &m);
    for (int x, y, i = 0; i < m; ++i)
    {
        scanf("%d%d", &x, &y);
        queries.addedge(y, x);
    }
    // Build the trie by replaying the typing: a letter moves to (or creates)
    // the child, 'B' moves back to the parent, 'P' records the current node
    // as the end node of the next printed string.
    for (int i = 0; i < len; ++i)
    {
        if (s[i] == 'P') strNode[++strs] = now;
        else if (s[i] == 'B') now = t[now].f;
        else
        {
            if (t[now].son[s[i] - 'a']) now = t[now].son[s[i] - 'a'];
            else
            {
                int cur = now;
                t[now = t[now].son[s[i] - 'a'] = ++tot].f = cur;
            }
        }
    }
    // BFS over the trie to compute fail pointers; each node is also added as
    // a child of its fail target in the fail-pointer tree.
    int l = 0, r = -1;
    for (int i = 0; i < 26; ++i)
    {
        if (t[root].son[i])
        {
            queue[++r] = t[root].son[i];
            fail.addedge(root, queue[r]);
            t[queue[r]].fail = root;
        }
    }
    for (; l <= r; ++l)
    {
        for (int i = 0; i < 26; ++i)
            if (t[queue[l]].son[i])
            {
                queue[++r] = t[queue[l]].son[i];
                // Follow fail pointers until a node with a child on letter i, or the root.
                for (now = t[queue[l]].fail; now != root && !t[now].son[i]; now = t[now].fail) ;
                t[queue[r]].fail = t[now].son[i] ? t[now].son[i] : root;
                fail.addedge(t[queue[r]].fail, queue[r]);
            }
    }
    // DFS order of the fail-pointer tree, so every subtree is a contiguous range.
    dfs(root);
    // Replay the typing again, keeping +1 in the BIT at the DFS position of
    // every node on the current root-to-`now` path. When string y is printed,
    // each query (x, y) is answered with the number of those nodes inside the
    // fail-tree subtree of x's end node.
    now = root, strs = 0;
    for (int i = 0; i < len; ++i)
    {
        if (s[i] == 'B')
        {
            modify(head[now], -1);
            now = t[now].f;
        }
        else if (s[i] != 'P')
        {
            now = t[now].son[s[i] - 'a'];
            modify(head[now], 1);
        }
        else
        {
            for (int x = queries.head[++strs]; x; x = queries.e[x].next)
                ans[x] = query(tail[strNode[queries.e[x].node]]) - query(head[strNode[queries.e[x].node]] - 1);
        }
    }
    for (int i = 1; i <= m; ++i)
        printf("%d\n", ans[i]);
    return 0;
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值