字典树trie
1. 原理
先看一张图
看不懂很正常
如果你看懂了,那么原理就可以直接跳过
比如说我们现在有几个字符串
in
inn
int
to
ten
tea
我们想知道te
是哪几个字符串的前缀?
这个时候我们显然可以一个一个去比较,但是这样的效率有点低,我们也可以想到用字符串hash去比较,但是这样的效率还是有点低
这个时候字典树就排上用场了
字典树 说白了 就是一个字典
和我们平时查英语的字典一样,比如我们现在要去查一个单词hello
我们首先第一个动作一定是放到 H 开头的那一页
再去放到 he 开头的那一页
直到我们找到了hello
读到这里了,你再回过头去看上面那张图,一切就都豁然开朗了
2.代码
一份标准的字典树 包含两个函数 insert()
和find()
也许一开始会很懵逼,但是跟着敲一次你也许就能明白了
下面有例题,可以帮助理解一下
const int maxn = 1e6 + 10;
int trie[maxn][26]; // 存储下一个字符的位置
int num[maxn]; // 以当前字符串为前缀的单词的数量
int pos = 1; // 新分配的位置
void insert(char *str)
{
int p = 0;
for (int i = 0; str[i]; i++)
{
int n = str[i] - 'A';
if (trie[p][n] == 0) // 从未遇到过的字符串
{
trie[p][n] = pos++;
}
p = trie[p][n];
num[p]++;
}
}
int find(char *str)
{
int p = 0;
for (int i = 0; str[i]; i++)
{
int n = str[i]-'A';
if(trie[p][n]==0) return 0;
p = trie[p][n];
}
return num[p];
}
3. 例题
两道模板题
-
https://ac.nowcoder.com/acm/problem/16864
// https://ac.nowcoder.com/acm/problem/16864 #include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <algorithm> typedef long long ll; typedef unsigned long long ull; using namespace std; const int maxn = 1e6 + 10; int trie[maxn][26]; int num[maxn]; int pos = 1; void insert(char *str) { int p = 0; for (int i = 0; str[i]; i++) { int n = str[i] - 'A'; if (trie[p][n] == 0) { trie[p][n] = pos++; } p = trie[p][n]; num[p]++; } } int find(char *str) { int p = 0; for (int i = 0; str[i]; i++) { int n = str[i]-'A'; if(trie[p][n]==0) return 0; p = trie[p][n]; } return num[p]; } int main() { char str[maxn]; while(scanf("%s",str)!=EOF){ insert(str); } printf("%d",pos); return 0; }
-
https://www.luogu.com.cn/problem/UVA644
// https://www.luogu.com.cn/problem/UVA644 #include <iostream> #include <cstdio> #include <cstring> #include <cstdlib> #include <algorithm> using namespace std; typedef long long ll; const int maxn = 1000+10; int trie[maxn][2]; bool wend[maxn]; char str[maxn]; bool flag; int pos = 1; void init() { memset(str, 0, sizeof(str)); memset(trie, 0, sizeof(trie)); memset(wend, 0, sizeof(wend)); flag = 0; pos = 1; } void insert(char *str) { int p = 0; for (int i = 0; str[i]; i++) { int n = str[i] - '0'; if (trie[p][n] == 0) { trie[p][n] = pos++; } if (wend[p] == 1) { flag = 1; return; } p = trie[p][n]; if (str[i + 1] == 0) { if (wend[p] == 1||trie[p][0]!=0||trie[p][1]!=0) flag = 1; wend[p] = 1; } } } int main() { int kase = 1; while (scanf("%s", str) != EOF) { if (str[0] == '9') { if (flag) { printf("Set %d is not immediately decodable\n", kase++); } else { printf("Set %d is immediately decodable\n", kase++); } init(); continue; } if (!flag) insert(str); } return 0; }