c++前缀树实现

c++前缀树实现

#include <iostream>
#include <vector>
#include <unordered_map>
#include <memory>

template<typename V>
class TrieMap {

public:

    class Option {
    public:
        V val{};
        bool isNone = true;
    };

private:
    class TrieNode {
    public:
        Option option;
        std::unordered_map<char, std::shared_ptr<TrieNode>> children;
    };

public:
    /***** 增/改 *****/

    // 在 Map 中添加 key
    void put(const std::string &key, V val) {
        if (!containsKey(key)) {
            word_size++;
        }

        root = putImpl(root, key, val, 0);
    }

    /***** 删 *****/

    // 删除键 key 以及对应的值
    void remove(const std::string &key) {
        if (!containsKey(key)) {
            return;
        }

        root = removeImpl(root, key, 0);

        word_size--;
    }

    /***** 查 *****/

    // 搜索 key 对应的值,不存在则返回 null
    // get("the") -> 4
    // get("tha") -> null
    TrieMap::Option get(const std::string &key) {
        auto x = getNode(root, key);
        if (x == nullptr || x->option.isNone) {
            TrieMap::Option tmp;

            return tmp;
        }

        return x->option;
    }

    // 判断 key 是否存在在 Map 中
    // containsKey("tea") -> false
    // containsKey("team") -> true
    bool containsKey(const std::string &key) {
        return !get(key).isNone;
    }

    // 在 Map 的所有键中搜索 query 的最短前缀
    // shortestPrefixOf("themxyz") -> "the"
    std::string shortestPrefixOf(const std::string &query) {
        auto p = root;
        int i = 0;
        for (auto ch: query) {
            if (!p->children.count(ch)) {
                return "";
            }

            if (!p->option.isNone) {
                return query.substr(0, i);
            }

            p = p->children[ch];
            i++;
        }

        if (p != nullptr && !p->option.isNone) {
            return query;
        }

        return "";

    }

    // 在 Map 的所有键中搜索 query 的最长前缀
    // longestPrefixOf("themxyz") -> "them"
    std::string longestPrefixOf(const std::string &query) {
        int max_len = 0;
        auto p = root;
        int i = 0;
        for (auto ch: query) {
            if (!p->children.count(ch)) {
                return "";
            }

            if (!p->option.isNone) {
                max_len = i;
            }

            p = p->children[ch];
            i++;
        }

        if (p != nullptr && !p->option.isNone) {
            return query;
        }

        return query.substr(0, max_len);
    }

    // 搜索所有前缀为 prefix 的键
    // keysWithPrefix("th") -> ["that", "the", "them"]
    std::vector<std::string> keysWithPrefix(const std::string &prefix) {
        std::vector<std::string> res{};

        auto x = getNode(root, prefix);

        if (x == nullptr) {
            return res;
        }

        traverse(x, prefix, res);
        return res;

    }

    // 判断是和否存在前缀为 prefix 的键
    // hasKeyWithPrefix("tha") -> true
    // hasKeyWithPrefix("apple") -> false
    bool hasKeyWithPrefix(const std::string &prefix) {
        return getNode(root, prefix) != nullptr;
    }

    // 通配符 . 匹配任意字符,搜索所有匹配的键
    // keysWithPattern("t.a.") -> ["team", "that"]
    std::vector<std::string> keysWithPattern(const std::string &pattern) {
        std::vector<std::string> res{};
        traversePattern(root, "", pattern, 0, res);
        return res;
    }

    // 通配符 . 匹配任意字符,判断是否存在匹配的键
    // hasKeyWithPattern(".ip") -> true
    // hasKeyWithPattern(".i") -> false
    bool hasKeyWithPattern(const std::string &pattern) {
        return hasKeyWithPatternImpl(root, pattern, 0);
    }

    // 返回 Map 中键值对的数量
    int size() {
        return word_size;
    }


private:
    std::shared_ptr<TrieNode> getNode(std::shared_ptr<TrieNode> node, const std::string &key) {
        auto p = node;

        if (node == nullptr) {
            return nullptr;
        }

        for (auto ch: key) {
            if (!p->children.count(ch)) {
                return nullptr;
            }

            p = p->children[ch];
        }

        return p;
    }


    void traverse(std::shared_ptr<TrieNode> node, std::string path, std::vector<std::string> &res) {
        if (node == nullptr) {
            return;
        }

        if (!node->option.isNone) {
            res.push_back(path);
        }

        for (const auto &item: node->children) {
            path.push_back(item.first);
            traverse(item.second, path, res);
            path.erase(path.end());
        }
    }

    void traversePattern(std::shared_ptr<TrieNode> node, std::string path, const std::string &pattern, int i,
                         std::vector<std::string> &res) {
        if (node == nullptr) {
            return;
        }

        if (i == pattern.length()) {
            if (!node->option.isNone) {
                res.push_back(path);
            }
            return;
        }

        char c = pattern[i];

        if (c == '.') {
            for (const auto &item: node->children) {
                path.push_back(item.first);
                traversePattern(item.second, path, pattern, i + 1, res);
                path.erase(path.end());
            }
        } else {
            path.push_back(c);
            traversePattern(node->children[c], path, pattern, i + 1, res);
            path.erase(path.end());
        }
    }

    bool hasKeyWithPatternImpl(std::shared_ptr<TrieNode> node, std::string pattern, int i) {
        if (node == nullptr) {
            return false;
        }

        if (i == pattern.length()) {
            return !node->option.isNone;
        }

        char c = pattern[i];
        if (c != '.') {
            return hasKeyWithPatternImpl(node->children[c], pattern, i + 1);
        }

        for (const auto &item: node->children) {
            if (hasKeyWithPatternImpl(item.second, pattern, i + 1)) {
                return true;
            }
        }

        return false;
    }

    std::shared_ptr<TrieNode> putImpl(std::shared_ptr<TrieNode> node, const std::string &key, V val, int i) {
        if (node == nullptr) {
            node = std::make_shared<TrieNode>();
        }

        if (i == key.length()) {
            node->option.isNone = false;
            node->option.val = val;
            return node;
        }

        char c = key[i];

        node->children[c] = putImpl(node->children[c], key, val, i + 1);
        return node;
    }


    std::shared_ptr<TrieNode> removeImpl(std::shared_ptr<TrieNode> node, const std::string &key, int i) {
        if (node == nullptr) {
            return nullptr;
        }

        if (i == key.length()) {
            node->option.isNone = true;
        } else {
            char c = key[i];
            node->children[c] = removeImpl(node->children[c], key, i++);
        }


        if (!node->option.isNone) {
            return node;
        }

        for (auto item: node->children) {
            return node;
        }

        return nullptr;
    }

private:
    int word_size = 0;
    std::shared_ptr<TrieNode> root{};
};

class TrieSet {
    // 底层用一个 TrieMap,键就是 TrieSet,值仅仅起到占位的作用
    // 值的类型可以随便设置,我参考 Java 标准库设置成 Object
private:
    TrieMap<int> map{};

    /***** 增 *****/

    // 在集合中添加元素 key
public:
    void add(const std::string &key) {
        map.put(key, -1);
    }

    /***** 删 *****/

    // 从集合中删除元素 key
    void remove(const std::string &key) {
        map.remove(key);
    }

    /***** 查 *****/

    // 判断元素 key 是否存在集合中
    bool contains(const std::string &key) {
        return map.containsKey(key);
    }

    // 在集合中寻找 query 的最短前缀
    std::string shortestPrefixOf(const std::string &query) {
        return map.shortestPrefixOf(query);
    }

    // 在集合中寻找 query 的最长前缀
    std::string longestPrefixOf(const std::string &query) {
        return map.longestPrefixOf(query);
    }

    // 在集合中搜索前缀为 prefix 的所有元素
    std::vector<std::string> keysWithPrefix(const std::string &prefix) {
        return map.keysWithPrefix(prefix);
    }

    // 判断集合中是否存在前缀为 prefix 的元素
    bool hasKeyWithPrefix(const std::string &prefix) {
        return map.hasKeyWithPrefix(prefix);
    }

    // 通配符 . 匹配任意字符,返回集合中匹配 pattern 的所有元素
    std::vector<std::string> keysWithPattern(const std::string &pattern) {
        return map.keysWithPattern(pattern);
    }

    // 通配符 . 匹配任意字符,判断集合中是否存在匹配 pattern 的元素
    bool hasKeyWithPattern(const std::string &pattern) {
        return map.hasKeyWithPattern(pattern);
    }

    // 返回集合中元素的个数
    int size() {
        return map.size();
    }
};


class Trie {
public:
    Trie() {

    }

    TrieSet mSet;


    void insert(std::string word) {
        mSet.add(word);
    }

    bool search(std::string word) {
        return mSet.contains(word);
    }

    bool startsWith(std::string prefix) {
        return mSet.hasKeyWithPrefix(prefix);
    }

};

int main() {
    printf("hello world \n");

    Trie trie;
    trie.insert("apple");
    auto res = trie.search("apple");   // 返回 True
    std::cout << res << std::endl;
    res = trie.search("app");     // 返回 False
    std::cout << res << std::endl;
    res = trie.startsWith("app"); // 返回 True
    std::cout << res << std::endl;
    trie.insert("app");
    res = trie.search("app");     // 返回 True
    std::cout << res << std::endl;
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值