c++前缀树实现
#include <iostream>
#include <vector>
#include <unordered_map>
#include <memory>
template<typename V>
class TrieMap {
public:
class Option {
public:
V val{};
bool isNone = true;
};
private:
class TrieNode {
public:
Option option;
std::unordered_map<char, std::shared_ptr<TrieNode>> children;
};
public:
/***** 增/改 *****/
// 在 Map 中添加 key
void put(const std::string &key, V val) {
if (!containsKey(key)) {
word_size++;
}
root = putImpl(root, key, val, 0);
}
/***** 删 *****/
// 删除键 key 以及对应的值
void remove(const std::string &key) {
if (!containsKey(key)) {
return;
}
root = removeImpl(root, key, 0);
word_size--;
}
/***** 查 *****/
// 搜索 key 对应的值,不存在则返回 null
// get("the") -> 4
// get("tha") -> null
TrieMap::Option get(const std::string &key) {
auto x = getNode(root, key);
if (x == nullptr || x->option.isNone) {
TrieMap::Option tmp;
return tmp;
}
return x->option;
}
// 判断 key 是否存在在 Map 中
// containsKey("tea") -> false
// containsKey("team") -> true
bool containsKey(const std::string &key) {
return !get(key).isNone;
}
// 在 Map 的所有键中搜索 query 的最短前缀
// shortestPrefixOf("themxyz") -> "the"
std::string shortestPrefixOf(const std::string &query) {
auto p = root;
int i = 0;
for (auto ch: query) {
if (!p->children.count(ch)) {
return "";
}
if (!p->option.isNone) {
return query.substr(0, i);
}
p = p->children[ch];
i++;
}
if (p != nullptr && !p->option.isNone) {
return query;
}
return "";
}
// 在 Map 的所有键中搜索 query 的最长前缀
// longestPrefixOf("themxyz") -> "them"
std::string longestPrefixOf(const std::string &query) {
int max_len = 0;
auto p = root;
int i = 0;
for (auto ch: query) {
if (!p->children.count(ch)) {
return "";
}
if (!p->option.isNone) {
max_len = i;
}
p = p->children[ch];
i++;
}
if (p != nullptr && !p->option.isNone) {
return query;
}
return query.substr(0, max_len);
}
// 搜索所有前缀为 prefix 的键
// keysWithPrefix("th") -> ["that", "the", "them"]
std::vector<std::string> keysWithPrefix(const std::string &prefix) {
std::vector<std::string> res{};
auto x = getNode(root, prefix);
if (x == nullptr) {
return res;
}
traverse(x, prefix, res);
return res;
}
// 判断是和否存在前缀为 prefix 的键
// hasKeyWithPrefix("tha") -> true
// hasKeyWithPrefix("apple") -> false
bool hasKeyWithPrefix(const std::string &prefix) {
return getNode(root, prefix) != nullptr;
}
// 通配符 . 匹配任意字符,搜索所有匹配的键
// keysWithPattern("t.a.") -> ["team", "that"]
std::vector<std::string> keysWithPattern(const std::string &pattern) {
std::vector<std::string> res{};
traversePattern(root, "", pattern, 0, res);
return res;
}
// 通配符 . 匹配任意字符,判断是否存在匹配的键
// hasKeyWithPattern(".ip") -> true
// hasKeyWithPattern(".i") -> false
bool hasKeyWithPattern(const std::string &pattern) {
return hasKeyWithPatternImpl(root, pattern, 0);
}
// 返回 Map 中键值对的数量
int size() {
return word_size;
}
private:
std::shared_ptr<TrieNode> getNode(std::shared_ptr<TrieNode> node, const std::string &key) {
auto p = node;
if (node == nullptr) {
return nullptr;
}
for (auto ch: key) {
if (!p->children.count(ch)) {
return nullptr;
}
p = p->children[ch];
}
return p;
}
void traverse(std::shared_ptr<TrieNode> node, std::string path, std::vector<std::string> &res) {
if (node == nullptr) {
return;
}
if (!node->option.isNone) {
res.push_back(path);
}
for (const auto &item: node->children) {
path.push_back(item.first);
traverse(item.second, path, res);
path.erase(path.end());
}
}
void traversePattern(std::shared_ptr<TrieNode> node, std::string path, const std::string &pattern, int i,
std::vector<std::string> &res) {
if (node == nullptr) {
return;
}
if (i == pattern.length()) {
if (!node->option.isNone) {
res.push_back(path);
}
return;
}
char c = pattern[i];
if (c == '.') {
for (const auto &item: node->children) {
path.push_back(item.first);
traversePattern(item.second, path, pattern, i + 1, res);
path.erase(path.end());
}
} else {
path.push_back(c);
traversePattern(node->children[c], path, pattern, i + 1, res);
path.erase(path.end());
}
}
bool hasKeyWithPatternImpl(std::shared_ptr<TrieNode> node, std::string pattern, int i) {
if (node == nullptr) {
return false;
}
if (i == pattern.length()) {
return !node->option.isNone;
}
char c = pattern[i];
if (c != '.') {
return hasKeyWithPatternImpl(node->children[c], pattern, i + 1);
}
for (const auto &item: node->children) {
if (hasKeyWithPatternImpl(item.second, pattern, i + 1)) {
return true;
}
}
return false;
}
std::shared_ptr<TrieNode> putImpl(std::shared_ptr<TrieNode> node, const std::string &key, V val, int i) {
if (node == nullptr) {
node = std::make_shared<TrieNode>();
}
if (i == key.length()) {
node->option.isNone = false;
node->option.val = val;
return node;
}
char c = key[i];
node->children[c] = putImpl(node->children[c], key, val, i + 1);
return node;
}
std::shared_ptr<TrieNode> removeImpl(std::shared_ptr<TrieNode> node, const std::string &key, int i) {
if (node == nullptr) {
return nullptr;
}
if (i == key.length()) {
node->option.isNone = true;
} else {
char c = key[i];
node->children[c] = removeImpl(node->children[c], key, i++);
}
if (!node->option.isNone) {
return node;
}
for (auto item: node->children) {
return node;
}
return nullptr;
}
private:
int word_size = 0;
std::shared_ptr<TrieNode> root{};
};
class TrieSet {
// 底层用一个 TrieMap,键就是 TrieSet,值仅仅起到占位的作用
// 值的类型可以随便设置,我参考 Java 标准库设置成 Object
private:
TrieMap<int> map{};
/***** 增 *****/
// 在集合中添加元素 key
public:
void add(const std::string &key) {
map.put(key, -1);
}
/***** 删 *****/
// 从集合中删除元素 key
void remove(const std::string &key) {
map.remove(key);
}
/***** 查 *****/
// 判断元素 key 是否存在集合中
bool contains(const std::string &key) {
return map.containsKey(key);
}
// 在集合中寻找 query 的最短前缀
std::string shortestPrefixOf(const std::string &query) {
return map.shortestPrefixOf(query);
}
// 在集合中寻找 query 的最长前缀
std::string longestPrefixOf(const std::string &query) {
return map.longestPrefixOf(query);
}
// 在集合中搜索前缀为 prefix 的所有元素
std::vector<std::string> keysWithPrefix(const std::string &prefix) {
return map.keysWithPrefix(prefix);
}
// 判断集合中是否存在前缀为 prefix 的元素
bool hasKeyWithPrefix(const std::string &prefix) {
return map.hasKeyWithPrefix(prefix);
}
// 通配符 . 匹配任意字符,返回集合中匹配 pattern 的所有元素
std::vector<std::string> keysWithPattern(const std::string &pattern) {
return map.keysWithPattern(pattern);
}
// 通配符 . 匹配任意字符,判断集合中是否存在匹配 pattern 的元素
bool hasKeyWithPattern(const std::string &pattern) {
return map.hasKeyWithPattern(pattern);
}
// 返回集合中元素的个数
int size() {
return map.size();
}
};
class Trie {
public:
Trie() {
}
TrieSet mSet;
void insert(std::string word) {
mSet.add(word);
}
bool search(std::string word) {
return mSet.contains(word);
}
bool startsWith(std::string prefix) {
return mSet.hasKeyWithPrefix(prefix);
}
};
int main() {
printf("hello world \n");
Trie trie;
trie.insert("apple");
auto res = trie.search("apple"); // 返回 True
std::cout << res << std::endl;
res = trie.search("app"); // 返回 False
std::cout << res << std::endl;
res = trie.startsWith("app"); // 返回 True
std::cout << res << std::endl;
trie.insert("app");
res = trie.search("app"); // 返回 True
std::cout << res << std::endl;
}