import collections
class TrieNode:
    """A single trie node: a lazily-populated child map plus an end-of-word flag."""

    def __init__(self):
        # defaultdict(TrieNode) creates missing children on first [] access
        # (used by insert); search must use .get to avoid creating nodes.
        self.children = collections.defaultdict(TrieNode)
        self.is_word = False


class Trie:
    """
    A letter trie.

    The root is a sentinel node marking the beginning of a word (like <bow>).
    The first layer holds every possible first letter of a word — e.g. for
    '中国' that is '中' — the second layer every possible second letter,
    and so on.
    """

    def __init__(self, use_single=True):
        """
        Args:
            use_single: when True, single-letter matches are also returned
                by enumerateMatch; when False, only matches of length >= 2.
        """
        self.root = TrieNode()
        self.max_depth = 0  # length of the longest inserted word
        # Minimum match length (exclusive bound) used by enumerateMatch.
        self.min_len = 0 if use_single else 1

    def insert(self, word):
        """Add *word* (an iterable of letters) to the trie."""
        current = self.root
        depth = 0
        for letter in word:
            current = current.children[letter]
            depth += 1
        current.is_word = True
        self.max_depth = max(self.max_depth, depth)

    def search(self, word):
        """Return True iff *word* was inserted as a complete word."""
        current = self.root
        for letter in word:
            # .get (not []) so lookups never create nodes in the defaultdict.
            current = current.children.get(letter)
            if current is None:
                return False
        return current.is_word

    def enumerateMatch(self, str1, space=""):
        """
        Return every prefix of *str1* that is a word in the trie,
        shortest first. If a multi-letter word matches, the single-letter
        match (if any) is filtered out.

        Fix: unlike the previous implementation, *str1* is no longer
        mutated (it used to be truncated in place with ``del str1[-1]``).

        Args:
            str1: the sequence of letters to match.
            space: separator joined between the letters of a matched word.
        Returns:
            List of matched words, shortest first.
        """
        matched = []
        # Ascending prefix lengths -> shorter words come first in `matched`.
        for end in range(self.min_len + 1, len(str1) + 1):
            prefix = str1[:end]
            if self.search(prefix):
                matched.append(space.join(prefix))
        if len(matched) > 1 and len(matched[0]) == 1:  # filter single character word
            matched = matched[1:]
        return matched
# --- demo: build a small lexicon trie and enumerate matches in a sentence ---
a = Trie(use_single=True)
for word in ["中国", "美国", "世界", "强国", "世界强国", "中", "国", "和", "美", "国"]:
    a.insert(word)
sentence = "中国和美国都是世界强国"
print(list(sentence))
print(a.enumerateMatch(list(sentence)))
# 进行句子搜索的代码如下 (the sentence-search code is as follows):
def sent_to_matched_words_boundaries(sent, lexicon_tree, max_word_num=None):
"""
输入一个句子和词典树, 返回句子中每个字所属的匹配词, 以及该字的词边界
字可能属于以下几种边界:
B-: 词的开始, 0
M-: 词的中间, 1
E-: 词的结尾, 2
S-: 单字词, 3
BM-: 既是某个词的开始, 又是某个词中间, 4
BE-: 既是某个词开始,又是某个词结尾, 5
ME-: 既是某个词的中间,又是某个词结尾, 6
BME-: 词的开始、词的中间和词的结尾, 7
Args:
sent: 输入的句子, 一个字的数组
lexicon_tree: 词典树
max_word_num: 最多匹配的词的数量
Args:
sent_words: 句子中每个字归属的词组
sent_boundaries: 句子中每个字所属的边界类型
"""
sent_length = len(sent)
sent_words = [[] for _ in range(sent_length)]
sent_boundaries = [[] for _ in range(sent_length)] # each char has a boundary
for idx in range(sent_length):
sub_sent = sent[idx:idx + lexicon_tree.max_depth] # speed using max depth
words = lexicon_tree.enumerateMatch(sub_sent)
if len(words) == 0 and len(sent_boundaries[idx]) == 0:
sent_boundaries[idx].append(3) # S-
else:
if len(words) == 1 and len(words[0]) == 1: # single character word
if len(sent_words[idx]) == 0:
sent_words[idx].extend(words)
sent_boundaries[idx].append(3) # S-
else:
if max_word_num:
need_num = max_word_num - len(sent_words[idx])
words = words[:need_num]
sent_words[idx].extend(words)
for word in words:
if 0 not in sent_boundaries[idx]:
sent_boundaries[idx].append(0) # S-
start_pos = idx + 1
end_pos = idx + len(word) - 1
for tmp_j in range(start_pos, end_pos):
if 1 not in sent_boundaries[tmp_j]:
sent_boundaries[tmp_j].append(1) # M-
sent_words[tmp_j].append(word)
if 2 not in sent_boundaries[end_pos]:
sent_boundaries[end_pos].append(2) # E-
sent_words[end_pos].append(word)
assert len(sent_words) == len(sent_boundaries)
new_sent_boundaries = []
idx = 0
for boundary in sent_boundaries:
if len(boundary) == 0:
print("Error")
new_sent_boundaries.append(0)
elif len(boundary) == 1:
new_sent_boundaries.append(boundary[0])
elif len(boundary) == 2:
total_num = sum(boundary)
new_sent_boundaries.append(3 + total_num)
elif len(boundary) == 3:
new_sent_boundaries.append(7)
else:
print(boundary)
print("Error")
new_sent_boundaries.append(8)
assert len(sent_words) == len(new_sent_boundaries)
return sent_words, new_sent_boundaries
# 如果只想返回命中的词 (if you only want the set of matched words):
def sent_to_matched_words_set(sent, lexicon_tree, max_word_num=None):
    """
    Return all distinct lexicon words matched anywhere in *sent*.

    Args:
        sent: the sentence, a list of characters.
        lexicon_tree: lexicon trie; must expose max_depth and enumerateMatch.
        max_word_num: unused; kept for signature compatibility with
            sent_to_matched_words_boundaries.

    Returns:
        Sorted list of the distinct matched words (despite the name, the
        return value is a list, as in the original implementation).
    """
    matched_words = set()
    for idx in range(len(sent)):
        # Only prefixes up to max_depth can match — slice for speed.
        sub_sent = sent[idx:idx + lexicon_tree.max_depth]
        matched_words.update(lexicon_tree.enumerateMatch(sub_sent))
    return sorted(matched_words)