字典树
1、字典树
字典树:trie树,用树结构来描述词典。
- 树状结构
- 每条边代表一个字符,字符串是一条路径
- 节点可以存储value
- 单词对应的是路径
字典树相对于普通树的结构来说,就是类似带权重的树,只是权重不是数字,而是每一个字符。同时当节点value不为none时,表示从根节点到该节点的路径对应的字符串就是一个词。
实现方案:
- 节点:就和二叉树类似,有个节点类,但是字典树节点需要保存子树和value.子树需要保存路径上的字符,显然用dict来实现。
- 树:主要要实现查询功能,其实就是按照查询字符来遍历node dict来判断是否存在。删除就是把对应字符串的value置为none,修改也是修改value值。
代码实现:
class Node:
def __init__(self, value):
self._children = {}
self._value = value
def _add_child(self, char, value, overwrite=False):
child = self._children.get(char)
if child is None:
child = Node(value)
self._children[char] = child
elif overwrite:
child._value = value
return child
class Trie(Node):
def __init__(self):
super().__init__(None)
def __contains__(self, item):
return self[item] is not None
def __getitem__(self, item):
state = self
for char in item:
state = state._children.get(char)
if state is None:
return None
return state._value
def __setitem__(self, key, value):
state = self
for i, char in enumerate(key):
if i<len(key)-1:
state = state._add_child(char, None, False)
else:
state = state._add_child(char, value, True)
if __name__ == '__main__':
trie = Trie()
# add
trie['自然'] = 'nature'
trie['自然人'] = 'human'
trie['自然语言'] = 'language'
trie['自语'] = 'talk to oneself'
trie['入门'] = 'introduction'
assert '自然' in trie
#delete
trie['自然'] = None
#modify
trie['自然语言处理'] = 'human language'
2、首字散列其余二分的字典树
首字散列其余二分的字典树:按照字面意思就能知道这个和上面的字典树的区别,它借用散列表来实现对第一个字符快速的查找,然后后面字符按照二分法查找。
实现方案:
原书中是提供了Java实现的代码路径,地址:https://github.com/hankcs/HanLP/blob/1.x/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BinTrie.java。
按照Java实现思路,我写了一下Python的实现。
- 散列表:按照书中提到的说法,因为Python散列函数返回值是64位的,远超于单个汉字的数目,还有就是对于意思相近字得到的hash值偏离比较远。第二点理由其实我有点不太理解,即使hash值偏差大应该也没太大问题吧。那么对于散列表的实现就是借助hash函数和list来实现。hash函数要么借用Java的hash函数要么自己写一个简单的hash函数。补充一下,当前查找中文单字的数量:《汉语大字典》现收楷书单字60370个。还有为啥只对首字散列其实也比较好理解,汉语中二字词最多。
- 二分查找:二分查找是对字符进行查找,需要对字符进行比较。其他的和普通的二分查找没有区别了。
- 节点:节点比普通的字典树节点多了节点状态这个属性,如果是单纯实现这个数据结构应该是不需要这个属性。但是可以方便理解和记录。
代码实现:
##base_node.py
from enum import Enum
class BaseNode:
def __init__(self, key=None, status=None, value=None):
self._children = []
self.key = key
self.value = value
self.status = status
def get_key(self):
return self.key
def get_value(self):
return self.value
def set_value(self, value):
self.value = value
def compare_to(self, other):
key = other.key if isinstance(other, BaseNode) else other
if self.key > key:
return 1
elif self.key < key:
return -1
else:
return 0
class Status(Enum):
# 为序列值指定value值
# 未指定,用于删除词条
UNDEFINED_0 = 0
# 不是词语的结尾
NOT_WORD_1 = 1
# 是个词语的结尾,并且还可以继续
WORD_MIDDLE_2 = 2
# 是个词语的结尾,并且没有继续
WORD_END_3 = 3
## array_tool.py 工具类
class ArrayTool:
@staticmethod
def binary_search(branches, node):
high = len(branches) - 1
if len(branches)<1:
return high
low = 0
while low<=high:
mid = (low+high)>>1
cmp = branches[mid].compare_to(node)
if cmp<0:
low = mid + 1
elif cmp>0:
high = mid - 1
else:
return mid
return -(low+1)
# node.py 节点类
from base_node import BaseNode, Status
from array_tool import ArrayTool
class Node(BaseNode):
"""
字典树节点
"""
def __init__(self, key=None, status=None, value=None):
super().__init__(key, status, value)
def add_child(self, node):
add = False
if not self._children:
self._children = BaseNode()._children
index = ArrayTool.binary_search(self._children, node)
if index >= 0:
target = self._children[index]
if node.status==Status.UNDEFINED_0 and target.status != Status.NOT_WORD_1:
target.status = Status.NOT_WORD_1
target.value = None
add = True
elif node.status==Status.NOT_WORD_1 and target.status == Status.WORD_END_3:
target.status = Status.WORD_MIDDLE_2
elif node.status==Status.WORD_END_3:
if target.status!=Status.WORD_END_3:
target.status = Status.WORD_MIDDLE_2
if target.get_value() is None:
add = True
target.set_value(node.get_value())
else:
insert = -(index + 1)
self._children.insert(insert, node)
add = True
return add
def get_child(self, key):
if self._children is None:
return None
index = ArrayTool.binary_search(self._children, key)
if index<0:
return None
return self._children[index]
#bin_trie.py 首字散列其余二分字典树
from pyhanlp import JClass
from base_node import Status, BaseNode
from Node import Node
def char_hash(str):
"""
哈希函数,调用Java 字符哈希函数
:param str:
:param length:
:return:
"""
return abs((hash(str)))%(10**5)
# return JClass('java.lang.Character')(str).hashCode()
class BinTrie(BaseNode):
def __init__(self, map=None):
self._size = 0
# self._children = [BaseNode()] * (65535+1)
self._children = [BaseNode()] * 100000
self.status = Status.NOT_WORD_1
if map:
for key, value in map.items():
self.put(key, value)
def put(self, key, value):
if len(key)==0:
return
branch = self
for char in key[:-1]:
branch.add_child(Node(char, Status.NOT_WORD_1, None))
branch = branch.get_child(char)
if branch.add_child(Node(key[-1], Status.WORD_END_3, value)):
self._size += 1
def __setitem__(self, key, value):
self.put(key, value)
def __getitem__(self, item):
state = self
for i, char in enumerate(item):
if state is None:
return None
# 这个地方要理解一下,对于第一个字符是按hash获取到对应的node,然后后面是判断node的子树
state = state.get_child(char)
if state is None:
return None
if state.status != Status.WORD_END_3 and state.status != Status.WORD_MIDDLE_2:
return None
return state.value
def __contains__(self, item):
return self[item] is not None
def remove(self, key):
branch = self
for char in key[:-1]:
if branch is None:
return
branch = branch.get_child(char)
if branch is None:
return
if branch.add_child(Node(key[-1], Status.UNDEFINED_0, None)):
self._size -= 1
def add_child(self, node):
add = False
key = node.get_key()
hash_index = char_hash(key)
target = self.get_child(hash_index)
if target.status is None:
self._children[hash_index] = node
add = True
else:
if node.status == Status.UNDEFINED_0 and target.status != Status.NOT_WORD_1:
target.status = Status.NOT_WORD_1
add = True
elif node.status == Status.NOT_WORD_1 and target.status == Status.WORD_END_3:
target.status = Status.WORD_MIDDLE_2
elif node.status == Status.WORD_END_3:
if target.status == Status.NOT_WORD_1:
target.status = Status.WORD_MIDDLE_2
if target.get_value() is None:
add = True
target.set_value(node.get_value())
return add
def __sizeof__(self):
return self._size
def get_key(self):
return 0
def get_child(self, key):
if isinstance(key, str):
key = char_hash(key)
return self._children[key]
在强调一下,代码是参考Java实现的,所以有Java开发的风格。在Java中base_node是一个抽象类,然后会封装很多有用的通用方法。然后还设计node这个类。这本身的设计利用多态的特性有很好的效果。
代码里面有两个地方讲一下。
- node类和bin_trie类都会有get_child方法,对应的就是首字散列和其余二分的体现。
- 散列函数:散列函数里写了两个类型的散列函数,一个是直接调用Java的函数,一个是利用Python的散列函数,然后取余的做法,这种可能不太准确啊,只是一个简单的想法。
效果测试:
还是按照书上面的测试代码,使用咱们自己实现的bin_trie来加载词典,然后使用前面解释的切分算法,看看效果。但是结果可能和咱们想的不一样。
先上测试代码:
import time
from bin_trie import BinTrie
from nlp.ch02.forward_segment import forward_segment
from nlp.ch02.backward_segment import backward_segment
from nlp.ch02.bidirectional_segment import bidirectional_segment
from pyhanlp import *
def load_dictionary():
"""
加载Hanlp中的mini词库
:return:
"""
IOUtil = JClass('com.hankcs.hanlp.corpus.io.IOUtil')
path = HanLP.Config.CoreDictionaryPath.replace('.txt', '.mini.txt')
print(path)
dic = IOUtil.loadDictionary([path])
return dic
def evaluate_speed(segment, text, dic, pressure):
start = time.time()
for i in range(pressure):
segment(text, dic)
elapsed_time = time.time()-start
print("%s :%.2f 万字/秒" %(segment.__name__, len(text)*pressure/10000/elapsed_time))
def tranfer2dict(dic):
res = {}
for key in dic:
value = dic.get(key)
res[key] = value
return res
def test_list(list, pressure):
start = time.time()
for i in range(pressure):
if i in list:
pass
print(f"list search cost: {time.time()-start}")
def test_set(set_dic, pressure):
start = time.time()
for i in range(pressure):
if i in set_dic:
pass
print(f"set search cost: {time.time()-start}")
if __name__ == '__main__':
text = "江西鄱阳湖干枯,中国最大淡水湖变成大草原"
pressure = 10000
dic = load_dictionary()
list_dic = [i for i in range(60000)]
set_dic = set(list_dic)
test_list(list_dic, pressure)
test_set(set_dic,pressure)
start = time.time()
dict_dic = tranfer2dict(dic)
bin_tree_dic = BinTrie(dict_dic)
print(f"build tree cost: {time.time()-start}")
print(forward_segment(text, bin_tree_dic))
# evaluate_speed(forward_segment, text, bin_tree_dic, pressure)
# evaluate_speed(backward_segment, text, bin_tree_dic, pressure)
# evaluate_speed(bidirectional_segment, text, bin_tree_dic, pressure)
测试结果:
forward_segment :4.34 万字/秒
backward_segment :4.28 万字/秒
bidirectional_segment :2.21 万字/秒
是不是和咱们想的不一样,怎么使用自己设计的数据结构更慢了呢?书中不是说速度快一些吗?其实吧,这个速度是我把hash函数改成使用Python的hash函数,而不是调用Java的来的结果,其实按照调用Java的hash来说,速度还会更慢,大概是这个得四分之一。那么原因在哪呢?
Python变慢的原因:
其实这里得好好看一下对于这个切分算法来说,它核心的是不断查询切分的字符串是否在词典中。在切分算法那一节,load_dictionary函数返回的是set(dic.keySet()),也就是说返回的是set集合。
那么算法的快慢其实就是set集合和字典树查询的快慢问题了。对于Python来说set集合的查询速度是很快的,我的测试代码中其实有简单测试对比set和list查询的速度,结果如下:
list search cost: 0.605888843536377
set search cost: 0.0004830360412597656
大概速度的差了1200倍左右,当然这个测试可能不准确啊,那么为啥set会这么快啊?
set本身的底层实现是hash表实现的,理论上查找速度是O(1)的。dict是用来存储键值对结构的数据的,set其实也是存储的键值对,只是默认键和值是相同的。Python中的dict和set都是通过hash表和散列表来实现的。具体实现其实还有很多细节,特别是对于dict,这里就不多说了,后面有时间也写一下。(这里有一篇别的大佬写的,但是感觉还是没太底层,不过用来理解这个问题够用了。https://blog.csdn.net/siyue0211/article/details/80560783)
那么对于咱们自己的这个bin_tree查找速度如何呢?
前面散列查找O(1),后面是二分查找,速度是O(logN)的。所以就Python来说其实速度会比直接使用set慢。
那为啥Java会快了?
Java字典树使用的是TreeMap,使用的数据结构是红黑树,时间复杂度是O(logN),然后首字散列的二分查找树使用了散列表和二分查找,散列表的时间复杂度是O(1),后面是二分查找,速度是O(logN)的。所以其实bin_tree会快一些。
那为啥Java会比Python的快呢?
那如果说Python set时间复杂度低,那是不是Python实现的结果就应该比Java的快呢?显然不是,因为咱们Python实际上是C++实现的(说的CPython),其实还是会调用C++底层的数据结构,需要解释器解释执行。可能这么解释也不太准确,大概意思就是咱们Python语言执行确实会比Java慢。
3、前缀树的使用
前缀树简单点来说就是在字典树中,从根节点能访问到叶子节点A的话,那就能推论出根节点肯定能访问到节点A的父节点。我们一般说前缀很多都是说字符串,例如单词“自然语言”,那么对应的前缀“自然”、“自然语”都必须在树中存在。
其实从算法的角度考虑,其实前缀树就是会增加了对树遍历的判断条件,不再是遍历后面全部的元素。
- 前缀全切分:全切分有个地方需要注意就是每次其实是遍历一个字符,不管后面组成的词包含多少个字符。
- 前缀正向最大切分:判断是否是单词的边界条件是当前节点的value值是否存在,存在就是一个词,然后如何判断是最大的呢?那就是一直按照树遍历,然后每次value都有值的时候,就把他保存下来,这样value这个变量就一直是最长的那个词。这里要先判断节点的value是否存在,再赋值,不能直接赋值,避免出现value为None的情况。例如对于“最大”这个词,其实正向切分会切成【“最”,“大”】,但是词典中有“最大值”这个词。所以“大”这个节点的value是为None的。还有注意下一次遍历是从上一次切分的最大词后一位开始的,所以需要记录上一次词结束的位置。
代码实现:
是在上面前缀树的代码上实现的,修改了base_node.py和bin_tree.py文件
# base_node.py 添加下方法
def transition(self, text, begin=0):
cur = self
for i in range(begin, len(text)):
cur = cur.get_child(text[i])
if cur is None or cur.status == Status.UNDEFINED_0:
return None
return cur
# bin_tree.py 添加如下方法:
def parse_text(self, text):
"""
前缀全切分
:param text:
:return:
"""
length = len(text)
begin = 0
state = self
res = []
i = 0
while i < length:
state = state.transition(text[i])
if state is not None:
value = state.get_value()
if value is not None:
res.append(text[begin:i + 1])
else:
i = begin
begin += 1
state = self
i += 1
return res
def parse_longest_text(self, text):
"""
前缀正向最大匹配
:param text:
:return:
"""
length = len(text)
res = []
i = 0
while i < length:
state = self.transition(text[i])
if state is not None:
value = state.get_value()
# end指针指向的是单词的结尾索引,不包含最后那个字
to = i + 1
end = to
while to < length:
state = state.transition(text[to])
if state is None:
break
if state.get_value() is not None:
value = state.get_value()
end = to + 1
to += 1
if value is not None:
res.append(text[i:end])
i = end - 1
i += 1
return res
也是参考Java代码实现的,由于Java和Python的for循环还是有点区别的,还是改了一段时间才实现的,还是比较菜啊。简单测试了几个例子和之前的切分算法结果是一致的,代码应该没啥问题。
测试效果:
fully_segment :2.53 万字/秒
forward_segment :4.41 万字/秒
parse_text :18.99 万字/秒
parse_longest_text :30.65 万字/秒
如上所示,前缀树切分是后面两条记录,可以看到差不多速度提升了7倍多,还是很可观的。
4、总结
本次是针对字典树那一章节的学习,前后花的时间还是很长的,主要是对何晗大佬算法设计的理解。主要涉及是字典树、首字散列其余二分以及前缀树三个方面,花了很多时间查看大佬的Java代码,确实对Java没那么熟悉,阅读理解的比较慢。但是花时间来实现对应的Python代码还是很有必要的,对这些数据结构有了更多的理解。OK,明天开始双数组字典树的学习。