类型1:判断单词前缀有没有出现过
#include<bits/stdc++.h>
using namespace std;
const int maxn=2e6+5;
int trie[maxn][27];
int top;
void insert(char *s,int rt)
{
for(int i=0;i<strlen(s);i++)
{
int x=s[i]-'a';// 用x(0-25)记录字符a-z
if(trie[rt][x]==0)//如果为 0代表该字符没有被插入过,
{
//插入该字符,并记录此次插入标号
trie[rt][x]=++top;//按照插入的顺序记录编号
}
rt=trie[rt][x];//找下个节点
}
return ;
}
bool find(char *s,int rt)
{
for(int i=0;i<strlen(s);i++)//在字典树中找 需要寻找的单词
{ //遍历长度为单词长度,如果单词都存在就返回 1
int x=s[i]-'a';
if(trie[rt][x]==0) return 0;//只要有一个不满足返回0
rt=trie[rt][x];//找下个字母
}
return 1;
}
int main()
{
int n;
scanf("%d",&n);
char s[1000];
top=0;
int rt=0;
for(int i=0;i<n;i++)
{
cin>>s;
insert(s,rt);
}
int m;
scanf("%d",&m);
for(int i=0;i<m;i++)
{
cin>>s;
if(find(s,rt)) printf("YES\n");
else printf("NO\n");
}
}
类型2:给出一些单词,判断每个单词是否出现过(注意是整个单词,不是前缀),其实这种类型和第一种情况差不多,用一个vis记录单词结束就可以了
#include<bits/stdc++.h>
using namespace std;
const int maxn=2e6+10;
int trie[2000010][27];
bool vis[maxn];
//int sum[maxn];
int top;
void insert(char *s,int rt)
{
for(int i=0;i<strlen(s);i++)
{
int x=s[i]-'a';
if(trie[rt][x]==0)
{
trie[rt][x]=++top;
}
rt=trie[rt][x];
}
vis[rt]=1;
return ;
}
bool find(char *s,int rt)
{
for(int i=0;i<strlen(s);i++)
{
int x=s[i]-'a';
if(trie[rt][x]==0) return 0;
rt=trie[rt][x];
}
return vis[rt];
}
int main()
{
int n;
scanf("%d",&n);
char s[1000];
top=0;
int rt=0;
for(int i=0;i<n;i++)
{
cin>>s;
insert(s,rt);
}
int m;
scanf("%d",&m);
for(int i=0;i<m;i++)
{
cin>>s;
//printf("%d\n",find(s,rt));
if(find(s,rt)) printf("YES\n");
else printf("NO\n");
}
}
类型3:给出一些单词,输出一个单词A,问在这些单词里有多少单词满足前缀有A
也是一样,直接用一个sum[ ]数组记录一下次数就可以了
#include<bits/stdc++.h>
using namespace std;
const int maxn=2e6+10;
int trie[2000010][27];
bool vis[maxn];
int sum[maxn];//记录次数
int top;
void insert(char *s,int rt)
{
for(int i=0;i<strlen(s);i++)
{
int x=s[i]-'a';
if(trie[rt][x]==0)
{
trie[rt][x]=++top;
}
sum[trie[rt][x]]++;//次数++就可以了
rt=trie[rt][x];
}
vis[rt]=1;
return ;
}
int find(char *s,int rt)
{
for(int i=0;i<strlen(s);i++)
{
int x=s[i]-'a';
if(trie[rt][x]==0) return 0;
rt=trie[rt][x];
}
return sum[rt];//返回结束时的sum
}
int main()
{
int n;
scanf("%d",&n);
char s[1000];
top=0;
int rt=0;
while(gets(s)&&s[0]!='\0')
{
//cout<<s<<endl;
insert(s,rt);
}
while(cin>>s)
{
printf("%d\n",find(s,rt));
}
}