Implementing a Trie in Python and returning the words matched in a sentence

import collections

class TrieNode:
    def __init__(self):
        # children are created lazily on first access thanks to defaultdict
        self.children = collections.defaultdict(TrieNode)
        self.is_word = False  # True if the path from the root to this node spells a word

class Trie:
    """
    This Trie is a character tree.
    The root is a dummy node; its only role is to mark the beginning of a word, like <bow>.
    The first layer holds every possible first character of a word; for example,
        the first character of '中国' is '中'.
    The second layer holds every possible second character, and so on.
    """
    def __init__(self, use_single=True):
        self.root = TrieNode()
        self.max_depth = 0  # length of the longest inserted word
        # use_single=True also enumerates single-character words in enumerateMatch
        if use_single:
            self.min_len = 0
        else:
            self.min_len = 1

    def insert(self, word):
        current = self.root
        depth = 0
        for letter in word:
            # defaultdict access creates the child node if it does not exist yet
            current = current.children[letter]
            depth += 1
        current.is_word = True
        if depth > self.max_depth:
            self.max_depth = depth  # track the longest word so sentence search can bound its window

    def search(self, word):
        current = self.root
        for letter in word:
            # .get() does not create missing nodes, unlike the defaultdict access in insert()
            current = current.children.get(letter)
            if current is None:
                return False
        return current.is_word

    def enumerateMatch(self, str1, space=""):
        """
        Args:
            str1: the characters to match, as a list (it is truncated in place)
        Return:
            the matched words; if any multi-character word matches,
            single-character matches are dropped
        """
        matched = []
        while len(str1) > self.min_len:
            if self.search(str1):
                matched.insert(0, space.join(str1[:]))  # shorter words always come first
            del str1[-1]  # drop the last character and retry with the shorter prefix

        if len(matched) > 1 and len(matched[0]) == 1:  # filter single-character words
            matched = matched[1:]

        return matched

a = Trie(use_single=True)
words = ['中国',"美国","世界","强国","世界强国",'中', '国', '和', '美', '国']
for word in words:
    a.insert(word)
print(list("中国和美国都是世界强国"))
print(a.enumerateMatch(['中', '国', '和', '美', '国', '都', '是', '世', '界', '强', '国']))
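
With the lexicon above, the second print should output ['中国']: both '中' and '中国' match at the start of the sentence, the single-character match is filtered out, and no longer word starting with '中国' exists in the tree. Note that enumerateMatch truncates the list passed to it in place. As a minimal sketch of what the use_single flag changes (a hypothetical extra Trie b, reusing the classes above):

b = Trie(use_single=False)   # min_len = 1: bare single characters are never tested
for word in ['中国', '中']:
    b.insert(word)

print(b.enumerateMatch(list('中国人')))  # ['中国']
print(b.enumerateMatch(list('中人')))    # [] -- '中' alone is not enumerated
print(a.enumerateMatch(list('中人')))    # ['中'] -- use_single=True keeps single-character matches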

The sentence-search code is as follows:

def sent_to_matched_words_boundaries(sent, lexicon_tree, max_word_num=None):
    """
    Given a sentence and a lexicon tree, return the matched words that each
    character belongs to, together with that character's word-boundary label.
    A character can carry one of the following boundary labels:
        B-: start of a word, 0
        M-: middle of a word, 1
        E-: end of a word, 2
        S-: single-character word, 3
        BM-: both the start of one word and the middle of another, 4
        BE-: both the start of one word and the end of another, 5
        ME-: both the middle of one word and the end of another, 6
        BME-: start, middle and end of different words, 7

    Args:
        sent: the input sentence, as a list of characters
        lexicon_tree: the lexicon Trie
        max_word_num: the maximum number of matched words kept per character
    Returns:
        sent_words: the matched words each character belongs to
        sent_boundaries: the boundary label of each character
    """
    sent_length = len(sent)
    sent_words = [[] for _ in range(sent_length)]
    sent_boundaries = [[] for _ in range(sent_length)]  # each char has a boundary

    for idx in range(sent_length):
        # window bounded by the longest lexicon word; the slice is a copy,
        # so enumerateMatch's in-place truncation does not affect sent
        sub_sent = sent[idx:idx + lexicon_tree.max_depth]
        words = lexicon_tree.enumerateMatch(sub_sent)

        if len(words) == 0 and len(sent_boundaries[idx]) == 0:
            sent_boundaries[idx].append(3) # S-
        else:
            if len(words) == 1 and len(words[0]) == 1: # single character word
                if len(sent_words[idx]) == 0:
                    sent_words[idx].extend(words)
                    sent_boundaries[idx].append(3) # S-
            else:
                if max_word_num:
                    need_num = max_word_num - len(sent_words[idx])
                    words = words[:need_num]
                sent_words[idx].extend(words)
                for word in words:
                    if 0 not in sent_boundaries[idx]:
                        sent_boundaries[idx].append(0) # B-
                    start_pos = idx + 1
                    end_pos = idx + len(word) - 1
                    for tmp_j in range(start_pos, end_pos):
                        if 1 not in sent_boundaries[tmp_j]:
                            sent_boundaries[tmp_j].append(1) # M-
                        sent_words[tmp_j].append(word)
                    if 2 not in sent_boundaries[end_pos]:
                        sent_boundaries[end_pos].append(2) # E-
                    sent_words[end_pos].append(word)

    assert len(sent_words) == len(sent_boundaries)

    new_sent_boundaries = []
    for boundary in sent_boundaries:
        if len(boundary) == 0:
            print("Error")
            new_sent_boundaries.append(0)
        elif len(boundary) == 1:
            new_sent_boundaries.append(boundary[0])
        elif len(boundary) == 2:
            # two labels on one character: {B, M} -> 4 (BM), {B, E} -> 5 (BE), {M, E} -> 6 (ME)
            total_num = sum(boundary)
            new_sent_boundaries.append(3 + total_num)
        elif len(boundary) == 3:
            new_sent_boundaries.append(7)
        else:
            print(boundary)
            print("Error")
            new_sent_boundaries.append(8)
    assert len(sent_words) == len(new_sent_boundaries)

    return sent_words, new_sent_boundaries
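
A minimal usage sketch, assuming the Trie a built above is used as the lexicon tree:

sent = list("中国和美国都是世界强国")
sent_words, sent_boundaries = sent_to_matched_words_boundaries(sent, a)
for char, words, boundary in zip(sent, sent_words, sent_boundaries):
    print(char, words, boundary)
# '中' should get 0 (B-), the '国' in '中国' 2 (E-), '和' 3 (S-),
# '界' 6 (ME-: middle of '世界强国' and end of '世界'), and '强' 4 (BM-).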

If you only want the matched words:

def sent_to_matched_words_set(sent, lexicon_tree, max_word_num=None):
    """Return all lexicon words matched anywhere in the sentence, as a sorted list."""
    sent_length = len(sent)
    matched_words_set = set()
    for idx in range(sent_length):
        sub_sent = sent[idx:idx + lexicon_tree.max_depth]  # window bounded by the longest lexicon word
        words = lexicon_tree.enumerateMatch(sub_sent)
        matched_words_set.update(words)
    return sorted(matched_words_set)
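
And a sketch for the set version, again assuming the Trie a from above:

sent = list("中国和美国都是世界强国")
print(sent_to_matched_words_set(sent, a))
# Should list the words hit anywhere in the sentence: '中国', '美国', '世界',
# '世界强国', '强国', plus the single characters '和' and '国'; '中' and '美'
# are dropped because a longer match starts at the same position.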

 
