题目:Keywords Search——http://acm.hdu.edu.cn/showproblem.php?pid=2222
题意:给出 n 个模式串和一个文本串,问有多少个模式串在文本串中出现过。
思路:AC 自动机模板。注意:同一模式串在文本串中可能出现多次,但只计一次;而重复给出的相同模式串要分别计数(如样例中的 3 个 she)。
代码:
#include <iostream>
#include <cstdio>
#include <queue>
#include <cstring>
#include <algorithm>
#define LL long long
using namespace std;
// Array capacities: maxn sizes the trie, the text buffer and the edge pool;
// siz is the alphabet size ('a'..'z').
const int maxn = 1e6+5;
const int siz = 26;

char s[maxn];   // shared input buffer, reused for every pattern and for the text
int t;          // number of test cases
int n;          // number of patterns in the current test case
int dp[maxn];   // per-node visit counts (filled by AC.ask, summed by AC.dfs)
int ans[maxn];  // per-node matched-pattern counts (set by AC.dfs)
int head[maxn]; // first edge index of each node's adjacency list (-1 = none)
int cnt_edge;   // number of edges stored so far; edge indices start at 1

// One adjacency-list edge: its target and the index of the next edge
// leaving the same source node.
struct Edge{
    int to;
    int nxt;
};
Edge edge[maxn];

// Prepend the directed edge from -> to to the adjacency list of `from`.
void addedge(int from, int to){
    const int id = ++cnt_edge;
    edge[id].to = to;
    edge[id].nxt = head[from];
    head[from] = id;
}
//------------------------------------------------------//
// Aho-Corasick automaton over the lowercase alphabet. Node 0 is the root.
// The fail tree is stored in the global adjacency list (head/edge via addedge),
// and per-node counters live in the global dp[] and ans[].
struct AC_auto{
    int tr[maxn][siz], cnt;  // trie children (a full goto table after build_fail); cnt = last allocated node
    int fail[maxn];          // fail link of each node (0 = root)
    int ed[maxn];            // number of inserted patterns that end exactly at each node

    // Insert one pattern into the trie. Inserting the same word again increments
    // ed[] again, so duplicated patterns are counted separately in the answer.
    // Assumes s contains only 'a'..'z' (anything else would index tr out of range).
    void Insert(char *s){
        int p = 0, len = strlen(s);
        for(int i=0; i<len; i++){
            int v = s[i]-'a';
            if(!tr[p][v]) tr[p][v] = ++cnt;
            p = tr[p][v];
        }
        ed[p] ++;
    }

    // BFS from the root to compute fail links. Missing children are redirected to
    // tr[fail[p]][i], which turns tr[][] into a complete goto table. Every link
    // fail[v] = u is also recorded as a fail-tree edge u -> v through addedge().
    void build_fail(){
        queue<int>q;
        // Depth-1 nodes keep fail = 0, so they hang directly below the root in the fail tree.
        for(int i=0; i<siz; i++) if(tr[0][i]) q.push(tr[0][i]), addedge(0, tr[0][i]);
        while(q.size()){
            int p = q.front(); q.pop();
            for(int i=0; i<siz; i++){
                if(tr[p][i]) {
                    fail[tr[p][i]] = tr[fail[p]][i];
                    q.push(tr[p][i]);
                    addedge(tr[fail[p]][i], tr[p][i]);
                }else tr[p][i] = tr[fail[p]][i];
            }
        }
    }

    // Feed the text s through the automaton. For each position, increment dp[] of
    // the state reached there. Must be called after build_fail().
    // Assumes s contains only 'a'..'z'.
    void ask(char *s){
        int len = strlen(s), p = 0;
        for(int i=0; i<len; i++){
            p = tr[p][s[i]-'a'];
            dp[p] ++;
        }
    }

    // Post-order walk of the fail tree below x. It adds every child's dp into its
    // parent, so dp[x] ends up counting the text positions whose state has x on its
    // fail chain. If that count is non-zero, the ed[x] patterns ending at x are
    // counted once in ans[x], however many times they occur. The recursion depth is
    // bounded by the longest pattern, because a fail link always points to a shallower node.
    void dfs(int x){
        for(int i=head[x]; i!=-1; i=edge[i].nxt){
            int y = edge[i].to;
            dfs(y);
            dp[x] += dp[y];
        }
        if(dp[x]) ans[x] = ed[x];
    }

    // Reset the automaton for the next test case. Insert/build_fail only ever write
    // rows 0..cnt, so only those rows are cleared, instead of the whole maxn x siz
    // table on every call. Relies on the arrays starting zeroed (AC is a global object).
    void clr(){
        memset(fail, 0, sizeof(fail[0]) * (cnt + 1));
        memset(tr, 0, sizeof(tr[0]) * (cnt + 1));
        memset(ed, 0, sizeof(ed[0]) * (cnt + 1));
        cnt = 0;
    }

    // Debug dump of ed[], fail[] and every non-zero transition of nodes 0..cnt.
    void print(){
        cout << "ed: "; for(int i=1; i<=cnt; i++) cout << ed[i] << " "; cout << endl;
        cout << "fail: "; for(int i=1; i<=cnt; i++) cout << fail[i] << " "; cout << endl;
        for(int i=0; i<=cnt; i++){  // include the last allocated node cnt
            for(int j=0; j<siz; j++){
                if(tr[i][j])
                    printf("%d---%c----%d\n", i, (char)(j+'a'), tr[i][j]);
            }
        }
    }
}AC;
//---------------------------------------------------------//
// Per-test-case reset: empty the fail-tree adjacency list, zero the
// per-node counters, and clear the automaton.
void init(){
    cnt_edge = 0;
    memset(head, -1, sizeof head);  // -1 terminates every adjacency list
    memset(dp, 0, sizeof dp);
    memset(ans, 0, sizeof ans);
    AC.clr();
}
// Reads T test cases. Each one is: n, then n patterns, then the text.
// Prints, per test case, how many patterns occur in the text.
int main()
{
    if(scanf("%d", &t) != 1) return 0;
    while(t--){
        init();
        // Stop cleanly on truncated input instead of reusing stale buffers.
        if(scanf("%d", &n) != 1) return 0;
        for(int i=1; i<=n; i++){
            if(scanf("%s", s) != 1) return 0;
            AC.Insert(s);
        }
        AC.build_fail();
        if(scanf("%s", s) != 1) return 0;
        AC.ask(s);
        AC.dfs(0);
        // dfs() set ans[x] = ed[x] for every node whose string occurs in the text.
        int total = 0;
        for(int i=1; i<=AC.cnt; i++) total += ans[i];
        cout << total << '\n';
    }
    return 0;
}
/** Sample input (the first line is T = 4 test cases); expected output: 5 3 4 6
4
5
aba
bab
ab
ba
ababa
abababababababab
3
she
she
she
shesheshe
6
she
he
he
say
shr
her
yasherhs
6
a
ba
cba
dcba
baf
f
dcbafd
*/