(好好的ac自动机让我做成了tle自动机orz。。。)
题目:http://acm.hdu.edu.cn/showproblem.php?pid=2222
大意:给你一堆单词,最后给你一个长串字符串,问你这个字符串里面包含多少个上述出现过的单词(可以重叠)
这道题用ac自动机可以将时间控制在300ms之内~
这道题的入门我是看的b站上面的ac自动机的视频,这里安利一下
看到这个算法猛然想到我几个星期之前打的华工校赛有一道题好像就是ac自动机的裸题,那个时候我的想法是for+kmp ,23333
其实这也是学了kmp的人看到这道题的正常想法,但是可惜用kmp会超时,kmp是单模式匹配,ac自动机这道题是多模式匹配。
ac自动机用的是kmp的next数组思想,在ac自动机上的next数组其实是fail数组(或者说是fail树),除了kmp之外主要是需要了解trie字典树,因为此题需要构建字典树。
下面我会写出我学习中遇到的一些问题以及我的想法
先贴代码
#include<cstdio>
#include<iostream>
#include<queue>
#include<algorithm>
#include<cstring>
using namespace std;
//const int maxn_n = 10000 + 7;
const int maxn_n = 10000 * 50;
const int maxn_len = 1e6 + 7;
struct ac_auto {
int trie[maxn_n][28];
int cnt[maxn_n];
int fail[maxn_n];
//cnt : 记录该点是不是一个单词的结尾 fail;失败数组
int last[maxn_n];
//解决单词重叠问题的数组
int ind,ans;
void init() {
ind = 1;
ans = 0; //ind:协助insert的,用于记录trie的编号
memset(trie[0], 0, sizeof(trie[0]));
memset(cnt, 0, sizeof(cnt));
memset(fail, 0, sizeof(fail));
}
void insert(char* p) {
int now = 0;
for (; *p; p++) { //还有这种操作。。。
int id = *p - 'a';
if (!trie[now][id]) { //假如还没有就创建
memset(trie[ind], 0, sizeof(trie[ind]));
trie[now][id] = ind++;
}
now = trie[now][id];
}
cnt[now]++;//记录该处一个为单词结束
}
void get_fail() { // bfs 建立 fail树
queue<int> Q;
for (int i = 0; i < 26; i++) {
if (trie[0][i]) {
Q.push(trie[0][i]);
last[trie[0][i]] = 0;
}
}
while (!Q.empty()) {
int now = Q.front();
Q.pop();
for (int i = 0; i < 26; i++) {
int to = trie[now][i];
if (!to) { // 假如没有,那么直接使用fail追溯回去
trie[now][i] = trie[fail[now]][i];
continue;
}
Q.push(to);
int tra = fail[now]; //tra=trace
while (tra && !trie[tra][i]) {//要么找到满足条件的祖先节点,要么就回根节点
tra = fail[tra];
}
fail[to] = trie[tra][i];
last[to] = cnt[fail[to]] ? fail[to] : last[fail[to]];
}
}
}
void solve(int i) {
if (!i) return;
if (cnt[i]) {
ans += cnt[i];
cnt[i] = 0;
}
solve(last[i]); //解决多串重叠问题
}
void find(char *p) {
int len = strlen(p);
int now = 0;
get_fail();
for (int i = 0; i < len; i++) {
int id = p[i] - 'a';
now = trie[now][id];
if (cnt[now]) solve(now);
else if (last[now]) solve(last[now]);
}
}
}AC;
char s[maxn_len],pp[55];
int main() {
int t; cin >> t;
while (t--) {
AC.init();
int n; scanf("%d", &n);
while (n--) {
scanf("%s", pp);
AC.insert(pp);
}
scanf("%s", s);
AC.find(s);
printf("%d\n", AC.ans);
}
return 0;
}
问题:
1.last数组是什么?
last数组就是解决多串重叠问题的数组
比如: 我给你 bc,abc ,最后给你的字符串是: qabcd
当匹配到c的时候,bc,abc都匹配得上,所以就这么写
看看last数组的生成方式:last[to] = cnt[fail[to]] ? fail[to] : last[fail[to]];
就是假如fail数组的那边也可以是一个字符的结尾的话,那么等于fail数组,不然就等于fail数组的last(- - 自己体会一下)
然后再结合这个看看
void solve(int i) {
if (!i) return;
if (cnt[i]) {
ans += cnt[i];
cnt[i] = 0;
}
solve(last[i]); //解决多串重叠问题
}
2. fail数组会不会彼此指向彼此
这个可以考虑一下bfs建立fail的方式,每个fail的指向一定是比它层数要高的,不可能存在这种情况