字典树主要用与字符串的保存,,查询。
第一个:使用数组;
将根结点编号为0,然后把其余结点(包括叶子节点和非叶子节点)编号为从1开始的正整数,然后用一个数组来保存每个结点的所有子节点,用下标直接存取。
ch[i][j]保存结点i的那个编号为j的子结点的编号什么叫“编号为j”呢?比如,若是处理全部
由小写字母组成的字符串,把所有小写字母按照字典序编号为0,1,2,…,则ch[i][0]表示结点编号为i的a的下一个节点的编号。如果这个子结点不存在,则ch[i][0]=0。
sigma_size表示字符集的大小,比如,当字符集为全体小写字母时,sigma_size=26。
maxnode:表示节点的总个数
val[i]>0当且仅当结点i是单词结点。每个字符串有一个权值,就可以把这个权值保存在val[i]中。:val[i] ,一维数组,我当时考虑的是"abc" 和“abd”种情况,和“app”和“application”这两种情况,其实不用考虑,
解释"abc",“abd”;前两个是公用的ab,b这个节点所在编号的c的编号和d 下的编号是不同的,(c是一个节点有一个编号,d也是一个节点,也有一个编号,编号是不同的,所以,用编号就可以代表一个不同的字符,)对于单词而言只有末尾的那个字符对应的编号是有权值的,
struct Trie
{
int ch[maxnode][26];
int val[maxnode];
int sz; // 结点总数
void clear()
{
sz = 1; // 初始时只有一个根结点
memset(ch[0], 0, sizeof(ch[0]));
}
int idx(char c)
{
return c - 'a'; // 字符c的编号
}
// 插入字符串s,附加信息为v。注意v必须非0,因为0代表“本结点不是单词结点”
void insert(const char *s, int v)
{
int u = 0, n = strlen(s);
for(int i = 0; i < n; i++)
{
int c = idx(s[i]);
if(!ch[u][c]) // 结点不存在
{
memset(ch[sz], 0, sizeof(ch[sz]));
val[sz] = 0; // 中间结点的附加信息为0
ch[u][c] = sz++; // 新建结点
}
u = ch[u][c]; // 往下走
}
val[u] = v; // 字符串的最后一个字符的附加信息为v
}
// 找字符串s的长度不超过len的前缀
void find_prefixes(const char *s, int len, vector<int> &ans)
{
int u = 0;
for(int i = 0; i < len; i++)
{
if(s[i] == '\0') break;
int c = idx(s[i]);
if(!ch[u][c]) break;
u = ch[u][c];
if(val[u] != 0) ans.push_back(val[u]); // 找到一个前缀
}
}
};
第二种是树:
#include <bits/stdc++.h>
using namespace std;
struct node{
node* next[26];
node* fail;
int cnt;
node(){
for (int i = 0; i <26; ++i)next[i] = NULL;
cnt = 0;
fail = NULL;
}
}*q[500010];
node *root;
int head, tail;
char str[1000005];
void insert(char *str)
{
node *p = root;
int i = 0, index;
while (str[i]){
index = str[i] - 'a';
if (p->next[index] == NULL)p->next[index] = new node();
p = p->next[index];
++i;
}
++p->cnt;
}
void build(node* root)
{
root->fail = NULL;
q[tail++] = root;
while (head < tail){
node* temp = q[head++];
node* p = NULL;
for (int i = 0; i < 26; ++i){
if (temp->next[i] != NULL){
if (temp == root)temp->next[i]->fail = root;
else{
p = temp->fail;
while (p != NULL){
if (p->next[i] != NULL){
temp->next[i]->fail = p->next[i];
break;
}
p = p->fail;
}
if (p == NULL)temp->next[i]->fail = root;
}
q[tail++] = temp->next[i];
}
}
}
}
int query(node *root)
{
int i = 0, cnt = 0, index;
node* p = root;
while (str[i]){
index = str[i] - 'a';
while (p->next[index] == NULL && p != root)p = p->fail;
p = p->next[index];
if (p == NULL)p = root;
node* temp = p;
while (temp != root && temp->cnt != -1){
cnt += temp->cnt;
temp->cnt = -1;
temp = temp->fail;
}
++i;
}
return cnt;
}
int main()
{
int t, n;
scanf("%d", &t);
while (t--){
head = tail = 0;
root = new node();
scanf("%d", &n);
while(n--){
scanf("%s", str);
insert(str);
}
build(root);
scanf("%s", str);
printf("%d\n", query(root));
}
return 0;
}
AC自动机
参考博客
题意第一行输入测试数据的组数,然后输入一个整数n,接下来的n行每行输入一个单词,最后输入一个字符串,问在这个字符串中有多少个单词出现过
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e7 + 5;
const int MAX = 10000000;
int cnt;
struct node{
node *next[26];
node *fail;
int sum;
};
node *root;
char key[70];
node *q[MAX];
int head,tail;
node *newnode;
char pattern[maxn];
int N;
void Insert(char *s)
{
node *p = root;
for(int i = 0; s[i]; i++)
{
int x = s[i] - 'a';
if(p->next[x] == NULL)
{
newnode=(struct node *)malloc(sizeof(struct node));
for(int j=0;j<26;j++) newnode->next[j] = 0;
newnode->sum = 0;newnode->fail = 0;
p->next[x]=newnode;
}
p = p->next[x];
}
p->sum++;
}
void build_fail_pointer()
{
head = 0;
tail = 1;
q[head] = root;
node *p;
node *temp;
while(head < tail)
{
temp = q[head++];
for(int i = 0; i <= 25; i++)
{
if(temp->next[i])
{
if(temp == root)
{
temp->next[i]->fail = root;
}
else
{
p = temp->fail;
while(p)
{
if(p->next[i])
{
temp->next[i]->fail = p->next[i];
break;
}
p = p->fail;
}
if(p == NULL) temp->next[i]->fail = root;
}
q[tail++] = temp->next[i];
}
}
}
}
void ac_automation(char *ch)
{
node *p = root;
int len = strlen(ch);
for(int i = 0; i < len; i++)
{
int x = ch[i] - 'a';
while(!p->next[x] && p != root) p = p->fail;
p = p->next[x];
if(!p) p = root;
node *temp = p;
while(temp != root)
{
if(temp->sum >= 0)
{
cnt += temp->sum;
temp->sum = -1;
}
else break;
temp = temp->fail;
}
}
}
int main()
{
int T;
scanf("%d",&T);
while(T--)
{
root=(struct node *)malloc(sizeof(struct node));
for(int j=0;j<26;j++) root->next[j] = 0;
root->fail = 0;
root->sum = 0;
scanf("%d",&N);
getchar();
for(int i = 1; i <= N; i++)
{
gets(key);
Insert(key);
}
gets(pattern);
cnt = 0;
build_fail_pointer();
ac_automation(pattern);
printf("%d\n",cnt);
}
return 0;
}