引入
判断一个字符串是否出现在另一个字符串中,暴力的解法肯定是行不通的,而利用字典树这个数据结构就可以以 logn 的时间处理查询问题。
字典树的建树(或者说是添加操作)的原理是这样的,先从字符串第一个字符开始遍历,再从根目录找相应的字符,如果能找到则继续向下找,直到找到最后一个字符处打上标记,代表有一个字符串以这个字符结束。如果中间不能找到,那么就再刚刚查找的那一层同层建立一个新节点赋为正在找的字符,往后每一个字符都自己新建节点,直到最后打上标记。
如图为字符串 "ab" "abc" "abd" "bc" 添加到树里的结果 :绿色代表打标记,如果要考虑计数问题,可以用 int 变量作为标记
字典树就是按照我们查字典的习惯建立的数,先从根部开始找第一个字母,找到后再以这个字母向下层找第二个字母,以此类推,直到找到最后一个字母,如果都能找到,则代表这个字符串在这个给定字符串中。
实现
在实际算法竞赛中,如果要用到字典树往往会使用数组来模拟,优点是速度快,而且c++的STL中没有树的容器,因此数组模拟是最优解。由于每一层都有可能有26个字母,因此用二维数组来存储:tire[N][26],每一个节点都会有自己的编号,因此实际上我们只需要存储编号就可以,用int 存储,标记数组一维便可以,cnt[N] 代表到编号为 i 的字符串是否出现,出现几次。
实际上树是这个样子:
int tire[100000][26], idx;
int cnt[100000]; // 该结点结尾的字符串是否存在
void insert(string s) { // 插入字符串
int p = 0;
for (int i = 0; i < s.size(); i++) {
int x = s[i] - 'a';
if (!tire[p][x]) tire[p][x] = ++idx; // 如果没有,就添加结点
p = tire[p][x];
}
cnt[p] ++;
}
int query(string s) { // 查找字符串
int p = 0;
for (int i = 0; i < s.size(); i++) {
int x = s[i] - 'a';
if (!tire[p][x]) return 0;
p = tire[p][x];
}
return cnt[p]; // 只有两个字符串一模一样才能返回非0,如果只是想查询是否包含,此处直接返回1
}
讲解
1.变量idx: idx代表字典树中每一个节点的编号。
2.trie[N][26]: 其中1~N为上方节点的编号,0代表root节点,1~26为连在i节点下方的26个字母。如果trie[i][x]=0,则代表字典树中目前没有这个点,而trie[i][x]的值代表这个点下方连有的点的编号。
3.cnt[N]: cnt[i] == 0代表编号为i的点不是一个单词的结束点,cnt[i]!=0代表编号为i的点是一个单词的结束点。
4.(难点)变量p:
p代表当前节点编号,初始化为0,代表初始节点,在函数的循环中,我们首先用 x 确定接下来要找的字母,再用 tire[p][x] 判断是否有下个字符。如果目标节点存在,就把p更新成目标节点的编号(p = trie[p][x])。而如果trie[p][x] == 0,代表字典树中没有这个点,查找失败。而如果是插入函数,我们就用 ++id 来把这个点存进字典树。我们在两个函数的最后用cnt[p]来标记。
例题
P2580 于是他错误的点名开始了 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
题解:
// Problem: P2580 于是他错误的点名开始了
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P2580
// Memory Limit: 128 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 3e5 + 9;
#define int long long
int n, m;
map<string, bool> mp;
int tire[N][26], idx;
int cnt[N]; // 该结点结尾的字符串是否存在
void insert(string s) { // 插入字符串
int p = 0;
int tt = s.size();
for (int i = 0; i < tt; i++) {
int x = s[i] - 'a';
if (!tire[p][x]) tire[p][x] = ++idx; // 如果没有,就添加结点
p = tire[p][x];
}
cnt[p] ++;
}
int query(string s) { // 查找字符串
int p = 0;
int tt = s.size();
for (int i = 0; i < tt; i++) {
int x = s[i] - 'a';
if (!tire[p][x]) return 0;
p = tire[p][x];
}
return cnt[p];
}
void solve() {
cin >> n;
for (int i = 1; i <= n; ++ i) {
string s; cin >> s;
insert(s);
}
cin >> m;
for (int i = 1; i <= m; ++ i) {
string s;cin >> s;
if (query(s)) {
if (!mp[s]) {
cout << "OK\n";
mp[s] = 1;
}else {
cout << "REPEAT\n";
}
}
else cout << "WRONG\n";
}
}
signed main() {
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
int _ = 1;
//cin >> _;
while (_--) {
solve();
}
return 0;
}
总结
自用复习以及分享给大家,有任何问题可以留言或私信。