我的博客链接:https://startcraft.cn
概述
AC自动机的目的是实现多模式串的匹配,即在一个主串中查询多个模式串
AC自动机是在trie(字典树)树的基础上实现的
对于一个主串,将这个各个模式串插入字典树,然后进行fail指针的生成fail指针的生成是AC自动机的关键
fail指针顾名思义就是在匹配失败时使用的,在匹配失败时下一步就是去fail指针指向的结点,避免了多次匹配造成的时间浪费
构建字典树
struct NODE
{
int cnt;//字符串出现的次数
int _next[26];//下一个字符
int _fail;//失配指针
};
void node_init(int cnt)//初始化结点
{
memset(node[cnt]._next, -1, sizeof(node[cnt]._next));
node[cnt]._fail = -1;
node[cnt].cnt = 0;
}
插入字符
void insert_char(char *s)
{
int p = 0;
int len = strlen(s);
int i;
for(i=0;i<len;i++)
{
int t = s[i] - 'a';
if (node[p]._next[t] == -1)//该节点不在字符串内
{
node_init(++num);
node[p]._next[t] = num;
}
p = node[p]._next[t];
}
node[p].cnt++;//字符串出现次数++
}
失配指针的计算
对于一个结点C,假设它表示的字符是a,那么他的fail指针的计算就是沿着C的父节点的fail指针走,直到找到一个结点X,X有子节点T表示的字符也是a,那么fail指针就从结点C指向X的子节点T,找不到fail就指向根节点
特别的对于直接与根节点相连的结点
该过程可以用bfs实现
代码:
void cal_fail()
{
queue<NODE>qu;
int i;
for(i=0;i<26;i++)
{
if (node[0]._next[i] != -1)
{
int t = node[0]._next[i];
node[t]._fail = 0;//将与根结点直接相连的结点的fail指针置为0即指向根节点
qu.push(node[t]);
}
}
while (!qu.empty())
{
NODE temp = qu.front();
qu.pop();
for(i=0;i<26;i++)
{
if (temp._next[i] != -1)
{
int p = temp._fail;//父节点的fail指针
while (p != -1 && node[p]._next[i] == -1)
{
p = node[p]._fail;
}
if (p == -1)//找不到
{
node[temp._next[i]]._fail = 0;
} else if (node[p]._next[i] != -1)//找到符合要求的结点
node[temp._next[i]]._fail = node[p]._next[i];
qu.push(node[temp._next[i]]);
}
}
}
}
查询
void query()
{
int len = strlen(ss);//ss为主串
int i;
int p = 0;
for(i=0;i<len;i++)
{
int t = ss[i] - 'a';
while (p != 0 && node[p]._next[t] == -1)//失配时顺着fail指针去查找
p = node[p]._fail;
p = node[p]._next[t];
int temp;
if (p != -1)
temp = p;
else
{
temp = p = 0;//没找到置为0
}
while (temp != 0 && node[temp].cnt != -1)//当前结点是一个模式字符串的结尾
{
ans += node[temp].cnt;
node[temp].cnt = -1;//避免重复计算
temp = node[temp]._fail;
}
}
}
例题:HDU 2222
AC代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
typedef long long ll;
#define wfor(i,j,k) for(i=j;i<k;++i)
#define mfor(i,j,k) for(i=j;i>=k;--i)
// void read(int &x) {
// char ch = getchar(); x = 0;
// for (; ch < '0' || ch > '9'; ch = getchar());
// for (; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
// }
const int maxn = 500005;
struct NODE
{
int cnt;
int _next[26];
int _fail;
};
NODE node[maxn];
void node_init(int cnt);
void insert_char(char *s);
void query();
void cal_fail();
char ss[1000005];
int num;
int ans;
int main()
{
std::ios::sync_with_stdio(false);
#ifdef test
freopen("F:\\Desktop\\question\\in.txt", "r", stdin);
#endif
int t;
cin >> t;
while (t--)
{
num = 0;
ans = 0;
node_init(0);
int n;
cin >> n;
char tt[60];
int i;
wfor(i, 0, n)
{
cin >> tt;
insert_char(tt);
}
cal_fail();
cin >> ss;
query();
cout << ans << endl;
}
return 0;
}
void node_init(int cnt)
{
memset(node[cnt]._next, -1, sizeof(node[cnt]._next));
node[cnt]._fail = -1;
node[cnt].cnt = 0;
}
void insert_char(char *s)
{
int p = 0;
int len = strlen(s);
int i;
wfor(i, 0, len)
{
int t = s[i] - 'a';
if (node[p]._next[t] == -1)
{
node_init(++num);
node[p]._next[t] = num;
}
p = node[p]._next[t];
}
node[p].cnt++;
}
void cal_fail()
{
queue<NODE>qu;
int i;
wfor(i, 0, 26)
{
if (node[0]._next[i] != -1)
{
int t = node[0]._next[i];
node[t]._fail = 0;
qu.push(node[t]);
}
}
while (!qu.empty())
{
NODE temp = qu.front();
qu.pop();
wfor(i, 0, 26)
{
if (temp._next[i] != -1)
{
int p = temp._fail;
while (p != -1 && node[p]._next[i] == -1)
{
p = node[p]._fail;
}
if (p == -1)
{
node[temp._next[i]]._fail = 0;
} else if (node[p]._next[i] != -1)
node[temp._next[i]]._fail = node[p]._next[i];
qu.push(node[temp._next[i]]);
}
}
}
}
void query()
{
int len = strlen(ss);
int i;
int p = 0;
wfor(i, 0, len)
{
int t = ss[i] - 'a';
while (p != 0 && node[p]._next[t] == -1)
p = node[p]._fail;
p = node[p]._next[t];
int temp;
if (p != -1)
temp = p;
else
{
temp = p = 0;
}
while (temp != 0 && node[temp].cnt != -1)
{
ans += node[temp].cnt;
node[temp].cnt = -1;
temp = node[temp]._fail;
}
}
}