http://acm.hdu.edu.cn/showproblem.php?pid=2222
题意:给出n个串,然后给一篇文章,问这n个串有多少个在文章里面出现过。。。
trick:n个串可能有相同的,需按照不同串处理。
分析:AC自动机模板题,自己按照算法思想写的,写得郁闷。。贴代码。。
不知道怎么要800+MS,别人的都基本上200+MS。。。
7月21号重新研究别人代码,终于找到差距了。。减到了187MS。。。附代码和改正说明:
7.21代码。。。
#include<iostream>
#include<stdio.h>
using namespace std;
const int N=1000010;
int n, a[N], up;
char s[N];
struct node
{
int a[26];
int cnt, fail;
void init()
{
memset(a, -1, sizeof(a));
cnt = fail = 0;
}
}trie[N];
inline void insert(char *s)
{
int p=0;
while(*s)
{
if(trie[p].a[*s-'a']==-1)
{
trie[up].init();
trie[p].a[*s-'a'] = up++;
}
p = trie[p].a[*s-'a'];
s++;
}
trie[p].cnt++;
}
int q[N], head, tail;
void bfs()
{
int i, p, p1, p2;
head = tail = 0;
for(i=0; i<26; i++)
{
if(trie[0].a[i]!=-1)
{
p = trie[0].a[i];
trie[p].fail = 0;
q[tail++] = p;
}
}
while(head<tail)
{
p = q[head++];
for(i=0; i<26; i++)
{
if(trie[p].a[i]!=-1)
{
p2 = trie[p].a[i];
q[tail++] = p2;
trie[p2].fail = 0;
p1 = trie[p].fail;
while(p1!=0 && trie[p1].a[i]==-1)
p1 = trie[p1].fail;
if(trie[p1].a[i]!=-1)
{
p1 = trie[p1].a[i];
trie[p2].fail = p1;
}
}
}
}
}
int query(char *s)
{
int p, p1;
int cnt=0;
p = 0;
while(*s)
{
while(p!=0 && trie[p].a[*s-'a']==-1)
p = trie[p].fail;
p1 = trie[p].a[*s-'a'];
if(p1!=-1)
{
p = p1;
while(p1!=0 && trie[p1].cnt!=-1) //这里很关键。。。但注意:要改成一个以前不会出现的值-1。。。
{
cnt += trie[p1].cnt;
trie[p1].cnt = -1;
p1 = trie[p1].fail;
}
/*
while(p1!=0) //通过将本段while循环改成上面,减掉了不必要的再次查找一系列fail结点,从800+MS减到了187MS。。。
{
if(trie[p1].flag==0)
{
cnt += trie[p1].cnt;
trie[p1].flag = 1;
}
p1 = trie[p1].fail;
}
*/
}
s++;
}
return cnt;
}
int main()
{
int i, cas;
scanf("%d", &cas);
while(cas--)
{
scanf("%d", &n);
gets(s);
up = 1;
trie[0].init();
for(i=0; i<n; i++)
{
gets(s);
insert(s);
}
gets(s);
bfs();
//for(i=0; i<up; i++)
// printf("%d %d..\n", i, trie[i].fail);
printf("%d\n", query(s));
}
return 0;
}
7.17号的代码。。。
//注意本题可能出现相同的keywords
//为什么要800+MS呢。。。
#include<iostream>
using namespace std;
const int N=1000100;
const int N1=250000;
int n, q[N], head, tail;
__int64 ans, cnt;
char s[N];
struct node
{
int fail;
int p[26];
int flag;
bool visited;
} tree[N1];
void insert(char *s)
{
int p=0;
while(*s)
{
if(tree[p].p[*s-'a']==-1)
tree[p].p[*s-'a'] = cnt++;
p = tree[p].p[*s-'a'];
s++;
}
tree[p].flag++;
}
void bfs()
{
int i, tmp; //把p做成指针
int pp;
head = tail = 0;
for(i=0; i<26; i++)
{
if(tree[0].p[i]!=-1)
{
pp = tree[0].p[i];
tree[pp].fail = 0;
q[tail++] = pp;
}
}
while(head<tail)
{
tmp = q[head++];
for(i=0; i<26; i++)
{
if(tree[tmp].p[i]!=-1)
{
q[tail++] = tree[tmp].p[i];
tree[tree[tmp].p[i]].fail = 0;
pp = tree[tmp].fail;
while(pp!=0 && tree[pp].p[i]==-1)
pp = tree[pp].fail;
if(tree[pp].p[i]!=-1)
tree[tree[tmp].p[i]].fail = tree[pp].p[i];
}
}
}
}
int query(char *s) //flag==1 && count==0 cnt++;
{
int i=0, len=strlen(s);
int tmp = 0, p=0;
__int64 cnt=0;
while(*s)
{
while(p!=0 && tree[p].p[*s-'a']==-1)
p = tree[p].fail; //找到最末尾一个满足条件的。。。
tmp = tree[p].p[*s-'a'];
if(tmp!=-1)
{
p = tmp;
while(tmp!=0)
{
if(tree[tmp].flag!=0 && tree[tmp].visited==0)// && tree[tmp].count==0)
{
cnt += tree[tmp].flag;
tree[tmp].visited = 1;
}
tmp = tree[tmp].fail;
}
}
s++;
}
return cnt;
}
int main()
{
int i, cas;
scanf("%d", &cas);
while(cas--)
{
scanf("%d", &n);
cnt = 1;
for(i=0; i<N1&&i<50*n; i++)
{
tree[i].fail = 0;
tree[i].flag = 0;
tree[i].visited = 0;
memset(tree[i].p, -1, sizeof(tree[i].p));
}
gets(s);
for(i=0; i<n; i++)
{
//scanf("%s", s);
gets(s);
insert(s);
}
//scanf("%s", s);
gets(s);
bfs();
/*
for(i=1; i<=10; i++) //
{
printf("%d ..fail = %d\n", i, tree[i].fail);
}
*/
ans = query(s);
printf("%I64d\n", ans);
}
return 0;
}