题目链接:https://www.luogu.org/problem/P3796
题意:给多个模式串和一个文本串,要求出模式串匹配最多的次数是多少,并输出这些次数的模式串。
分析:求匹配次数很明显,把AC_qurey函数稍微改一下就好,但要输出这些模式串我就迷了,当时认为Trie只能上到下,到了一个结点无法回到其上的结点,就不会了,事实上还是很浅显,另外用一个数组记录位置就好了啊。。。。将字典树结点中的end改为只有其是一个单词的结尾才做标记,并且标记这个单词的次序。
clean函数应该就是AC自动机里的初始化了,因为不知道会用到多少边,故是每当cnt增加了(即出现了新的一条边)的时候我们再初始化
#include<bits/stdc++.h> using namespace std; const int maxn=1e6+10; const int inf=0x3f3f3f3f; typedef long long ll; #define meminf(a) memset(a,0x3f,sizeof(a)) #define mem0(a) memset(a,0,sizeof(a)); struct result{ int num,pos; }ans[200]; bool cmp(const result &a,const result &b){ if(a.num==b.num) return a.pos<b.pos; return a.num>b.num; } struct node{ int fail;//失配指针fail int vis[26];//子节点的位置,也就是字典树的那26个字母 int end;//如果是尾节点就记录 }AC[maxn]; char s[200][100];//用来输入模式串 char ss[maxn]; //用来输入文本串 int cnt=0;//Trie的指针 void clean(int x){ mem0(AC[x].vis); AC[x].end=0; AC[x].fail=0; } void insert(char *s,int pos){ int len=strlen(s); int now=0;//字典树的当前指针 for(int i=0;i<len;i++){ //Trie树没有这个子节点 if(AC[now].vis[s[i]-'a']==0) AC[now].vis[s[i]-'a']=++cnt,clean(cnt); //多组输入,需要清除 一个个清除,之前++cnt说明需要用到这个节点了 now=AC[now].vis[s[i]-'a']; } AC[now].end=pos;//标记该结点是一个单词的结尾 ,并标记这是第几个单词 } void get_fail(){ queue<int> que; for(int i=0;i<26;i++){//把第二层的fail指针都设为0 if(AC[0].vis[i]!=0) { AC[AC[0].vis[i]].fail=0; que.push(AC[0].vis[i]); } } while(!que.empty()) { int u=que.front();que.pop(); for(int i=0;i<26;i++){ if(AC[u].vis[i]!=0){ //如果当前结点的子节点存在,就将子节点的fail指针指向当前结点fail指针指向的结点的对应子节点处 AC[AC[u].vis[i]].fail=AC[AC[u].fail].vis[i]; que.push(AC[u].vis[i]); } else AC[u].vis[i]=AC[AC[u].fail].vis[i]; //否则直接将这个不存在的子节点指向当前结点fail指针指向结点的对应子节点处 } } } void AC_query(char* s){ int len=strlen(s); int now=0; for(int i=0;i<len;i++){ now=AC[now].vis[s[i]-'a']; for(int t=now;t!=0;t=AC[t].fail){ ans[AC[t].end].num++; } } } int main(){ int n; while(~scanf("%d",&n)) { if(n==0) break; cnt=0; clean(0);//每次到一个新的边时,都记得清空之前的数据 for(int i=1;i<=n;i++){ ans[i].num=0; ans[i].pos=i; scanf("%s",s[i]); insert(s[i],i); } AC[0].fail=0;//结束标志 get_fail(); //求出失配指针 scanf("%s",ss); AC_query(ss); sort(ans+1,ans+1+n,cmp); printf("%d\n",ans[1].num); for(int i=1;i<=n;i++){ if(ans[i].num==ans[1].num){ printf("%s\n",s[ans[i].pos]); } else break; } } return 0; }