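// 先把匹配上的串都标记。然后把每个匹配上的串的子串变成非结束。
// Approach: first mark every pattern that occurs in the text; then, for each matched pattern,
// mark all of its substrings as non-terminal so they are no longer counted.
// 直接会超时。优化找失败指针的地方。
// A straightforward implementation times out, so the fail-pointer walks are optimized:
// each walk stops at the first node already marked in the current phase.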
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <queue>
#include <algorithm>
using namespace std;
// Capacity of the trie node pool t[]; vis[] is sized by it as well.
const int maxn = 250009;

// Raw (compressed) text as read from input, and its expanded form.
char str1[5100009];
char str2[5100009];

// s[i] is the expanded i-th pattern (1-based); ss is the raw input buffer for one pattern.
char s[2509][1109];
char ss[1109];

// vis[id] is set by query() when pattern `id` is found in the text and cleared by reinsert().
int vis[maxn];

// Next free slot in the node pool t[].
int cnt;
// Number of patterns in the current test case.
int n;
// Aho-Corasick automaton node. Nodes are taken from the static pool t[]; root points at t[0].
// `key` encodes the node's state across the phases driven by main():
//    0  plain node (set by init)
//    1  end of a pattern (set by insert); `end` then holds that pattern's id
//   -1  pattern end reached by query() while scanning the text
//   -2  non-pattern node reached by query()
//   -3  node marked by reinsert() as a substring of a matched pattern
//   -4  node already visited by requery()
struct trie
{
    trie *ch[26];   // goto edges for 'A'..'Z'
    trie *fail;     // failure link, assigned by build_fail()
    int end, key;   // pattern id (valid when |key| == 1) and phase state, see above
    // Resets this node so it can be (re)used from the pool.
    void init()
    {
        // Null the children explicitly: passing NULL as memset's int fill value is
        // not portable, since NULL may be defined as nullptr.
        for (int i = 0; i < 26; ++i)
            ch[i] = NULL;
        fail = NULL;
        end = 0;    // clear stale ids left over from a previous test case
        key = 0;
    }
}*root, t[maxn];
void insert(char *s, int id) //正常的插入
{
trie *p = root;
for(int i=0; s[i]; i++)
{
int k = s[i]-'A';
if(p->ch[k]==NULL)
{
t[cnt].init();
p->ch[k] = &t[cnt++];
}
p = p->ch[k];
}
p->end = id;
p->key = 1;
}
void reinsert(char *a) //将其子串删掉
{
trie *p = root;
for(int i=0; a[i]; i++)
{
int k = a[i]-'A';
while(p->ch[k]==NULL && p!=root) p=p->fail;
p = p->ch[k];
p = (p==NULL)? root : p;
trie *tmp = p;
while(tmp!=root && tmp->key!=-3)
{
if(abs(tmp->key)==1 && vis[tmp->end])
{
vis[tmp->end] = 0;
}
tmp->key = -3;
tmp = tmp->fail;
}
}
}
queue<trie*>q;
void build_fail()
{
q.push(root);
root->fail = root;
while(!q.empty())
{
trie *p = q.front();
q.pop();
for(int k=0; k<26; k++)
if(p->ch[k]!=NULL)
{
q.push(p->ch[k]);
if(p==root)
{
p->ch[k]->fail = root;
continue;
}
trie *tmp = p->fail;
while( tmp!=root && tmp->ch[k]==NULL)
tmp = tmp->fail;
if(tmp->ch[k]) p->ch[k]->fail = tmp->ch[k];
else p->ch[k]->fail = root;
}
}
}
void query(char *a) //标记出现过的所有串
{
trie *p = root;
for(int i=0; a[i]; i++)
{
int k = a[i]-'A';
while(p->ch[k]==NULL && p!=root) p=p->fail;
p = p->ch[k];
p = (p==NULL)? root : p;
trie *tmp = p;
while(tmp!=root && tmp->key>=0)
{
if(tmp->key==1)
{
vis[tmp->end] = 1;
tmp->key = -1;
}
else tmp->key = -2;
tmp = tmp->fail;
}
}
}
int requery(char *a)
{
int ans = 0;
trie *p = root;
for(int i=0; a[i]; i++)
{
int k = a[i]-'A';
while(p->ch[k]==NULL && p!=root) p=p->fail;
p = p->ch[k];
p = (p==NULL)? root : p;
trie *tmp = p;
while(tmp!=root && tmp->key!=-4)
{
if(abs(tmp->key)==1) ans++;
tmp->key = -4;
tmp = tmp->fail;
}
}
return ans;
}
int main()
{
int tt, i, j, k;
scanf("%d", &tt);
while(tt--)
{
cnt = 1;
t[0].init();
root = &t[0];
scanf("%d", &n);
for(i = 1; i<=n; i++)
{
scanf("%s", ss);
for(k=0, j=0; ss[k]; k++)
{
if(ss[k]>='A' && ss[k]<='Z')
{
s[i][j] = ss[k];
j++;
}
else
{
k++;
int kk = ss[k]-'0';
k++;
while(ss[k]>='0'&&ss[k]<='9') kk = kk*10+ ss[k++]-'0';
while(kk--) s[i][j++] = ss[k];
k++;
}
}
s[i][j] = '\0';
insert(s[i], i);
}
build_fail();
scanf("%s", str1);
for(i=0, j=0; str1[i]; i++)
{
if(str1[i]>='A' && str1[i]<='Z')
{
str2[j] = str1[i];
j++;
}
else
{
i++;
int k = str1[i]-'0';
i++;
while(str1[i]>='0'&&str1[i]<='9') k = k*10+ str1[i++]-'0';
while(k--) str2[j++] = str1[i];
i++;
}
}
str2[j] = '\0';
memset(vis, 0, sizeof(vis));
query(str2);
for(int i=1; i<=n; i++)
if(vis[i])
{
reinsert( s[i] );
insert( s[i], i );
}
int ans = requery( str2 );
printf("%d\n", ans);
}
return 0;
}