自然语言处理入门学习(二)--字典树

字典树

1、字典树

字典树:trie树,用树结构来描述词典。

  • 树状结构
  • 每条边代表一个字符,字符串是一条路径
  • 节点可以存储value
  • 单词对应的是路径

字典树相对于普通树的结构来说,就是类似带权重的树,只是权重不是数字,而是每一个字符。同时当节点value不为none时,表示从根节点到该节点的路径对应的字符串就是一个词。
实现方案:

  • 节点:就和二叉树类似,有个节点类,但是字典树节点需要保存子树和value.子树需要保存路径上的字符,显然用dict来实现。
  • 树:主要要实现查询功能,其实就是按照查询字符来遍历node dict来判断是否存在。删除就是把对应字符串的value置为none,修改也是修改value值。

代码实现:

class Node:
    def __init__(self, value):
        self._children = {}
        self._value = value

    def _add_child(self, char, value, overwrite=False):
        child = self._children.get(char)
        if child is None:
            child = Node(value)
            self._children[char] = child
        elif overwrite:
            child._value = value
        return child


class Trie(Node):
    def __init__(self):
        super().__init__(None)

    def __contains__(self, item):
        return self[item] is not None

    def __getitem__(self, item):
        state = self
        for char in item:
            state = state._children.get(char)
            if state is None:
                return None
        return state._value

    def __setitem__(self, key, value):
        state = self
        for i, char in enumerate(key):
            if i<len(key)-1:
                state = state._add_child(char, None, False)
            else:
                state = state._add_child(char, value, True)

if __name__ == '__main__':
    trie = Trie()
    # add
    trie['自然'] = 'nature'
    trie['自然人'] = 'human'
    trie['自然语言'] = 'language'
    trie['自语'] = 'talk to oneself'
    trie['入门'] = 'introduction'
    assert '自然' in trie

    #delete
    trie['自然'] = None

    #modify
    trie['自然语言处理'] = 'human language'

2、首字散列其余二分的字典树

首字散列其余二分的字典树:按照字面意思就能知道这个和上面的字典树的区别,它借用散列表来实现对第一个字符快速的查找,然后后面字符按照二分法查找。

实现方案:

原书中是提供了Java实现的代码路径,地址:https://github.com/hankcs/HanLP/blob/1.x/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BinTrie.java。

按照Java实现思路,我写了一下Python的实现。

  • 散列表:按照书中提到的说法,因为Python散列函数返回值是64位的,远超于单个汉字的数目,还有就是对于意思相近字得到的hash值偏离比较远。第二点理由其实我有点不太理解,即使hash值偏差大应该也没太大问题吧。那么对于散列表的实现就是借助hash函数和list来实现。hash函数要么借用Java的hash函数要么自己写一个简单的hash函数。补充一下,当前查找中文单字的数量:《汉语大字典》现收楷书单字60370个。还有为啥只对首字散列其实也比较好理解,汉语中二字词最多。
  • 二分查找:二分查找是对字符进行查找,需要对字符进行比较。其他的和普通的二分查找没有区别了。
  • 节点:节点比普通的字典树节点多了节点状态这个属性,如果是单纯实现这个数据结构应该是不需要这个属性。但是可以方便理解和记录。

代码实现:

##base_node.py

from enum import Enum


class BaseNode:
    def __init__(self, key=None, status=None, value=None):
        self._children = []
        self.key = key
        self.value = value
        self.status = status

    def get_key(self):
        return self.key

    def get_value(self):
        return self.value

    def set_value(self, value):
        self.value = value

    def compare_to(self, other):
        key = other.key if isinstance(other, BaseNode) else other
        if self.key > key:
            return 1
        elif self.key < key:
            return -1
        else:
            return 0


class Status(Enum):
    # 为序列值指定value值
    # 未指定,用于删除词条
    UNDEFINED_0 = 0
    # 不是词语的结尾
    NOT_WORD_1 = 1
    # 是个词语的结尾,并且还可以继续
    WORD_MIDDLE_2 = 2
    # 是个词语的结尾,并且没有继续
    WORD_END_3 = 3
## array_tool.py 工具类

class ArrayTool:
    @staticmethod
    def binary_search(branches, node):
        high = len(branches) - 1
        if len(branches)<1:
            return high
        low = 0
        while low<=high:
            mid = (low+high)>>1
            cmp = branches[mid].compare_to(node)
            if cmp<0:
                low = mid + 1
            elif cmp>0:
                high = mid - 1
            else:
                return mid
        return -(low+1)
# node.py 节点类

from base_node import BaseNode, Status
from array_tool import ArrayTool

class Node(BaseNode):
    """
    字典树节点
    """
    def __init__(self, key=None, status=None, value=None):
        super().__init__(key, status, value)

    def add_child(self, node):
        add = False
        if not self._children:
            self._children = BaseNode()._children
        index = ArrayTool.binary_search(self._children, node)
        if index >= 0:
            target = self._children[index]
            if node.status==Status.UNDEFINED_0 and target.status != Status.NOT_WORD_1:
                target.status = Status.NOT_WORD_1
                target.value = None
                add = True
            elif node.status==Status.NOT_WORD_1 and target.status == Status.WORD_END_3:
                target.status = Status.WORD_MIDDLE_2
            elif node.status==Status.WORD_END_3:
                if target.status!=Status.WORD_END_3:
                    target.status = Status.WORD_MIDDLE_2
                if target.get_value() is None:
                    add = True
                target.set_value(node.get_value())
        else:
            insert = -(index + 1)
            self._children.insert(insert, node)
            add = True
        return add

    def get_child(self, key):
        if self._children is None:
            return None
        index = ArrayTool.binary_search(self._children, key)
        if index<0:
            return None
        return self._children[index]

#bin_trie.py 首字散列其余二分字典树

from pyhanlp import JClass
from base_node import Status, BaseNode
from Node import Node


def char_hash(str):
    """
    哈希函数,调用Java 字符哈希函数
    :param str:
    :param length:
    :return:
    """
    return abs((hash(str)))%(10**5)
    # return JClass('java.lang.Character')(str).hashCode()

class BinTrie(BaseNode):
    def __init__(self, map=None):
        self._size = 0
        # self._children = [BaseNode()] * (65535+1)
        self._children = [BaseNode()] * 100000
        self.status = Status.NOT_WORD_1
        if map:
            for key, value in map.items():
                self.put(key, value)

    def put(self, key, value):
        if len(key)==0:
            return
        branch = self
        for char in key[:-1]:
            branch.add_child(Node(char, Status.NOT_WORD_1, None))
            branch = branch.get_child(char)
        if branch.add_child(Node(key[-1], Status.WORD_END_3, value)):
            self._size += 1

    def __setitem__(self, key, value):
        self.put(key, value)

    def __getitem__(self, item):
        state = self
        for i, char in enumerate(item):
            if state is None:
                return None
            # 这个地方要理解一下,对于第一个字符是按hash获取到对应的node,然后后面是判断node的子树
            state = state.get_child(char)
        if state is None:
            return None
        if state.status != Status.WORD_END_3 and state.status != Status.WORD_MIDDLE_2:
            return None
        return state.value

    def __contains__(self, item):
        return self[item] is not None

    def remove(self, key):
        branch = self
        for char in key[:-1]:
            if branch is None:
                return
            branch = branch.get_child(char)
        if branch is None:
            return
        if branch.add_child(Node(key[-1], Status.UNDEFINED_0, None)):
            self._size -= 1

    def add_child(self, node):
        add = False
        key = node.get_key()
        hash_index = char_hash(key)
        target = self.get_child(hash_index)
        if target.status is None:
            self._children[hash_index] = node
            add = True
        else:
            if node.status == Status.UNDEFINED_0 and target.status != Status.NOT_WORD_1:
                target.status = Status.NOT_WORD_1
                add = True
            elif node.status == Status.NOT_WORD_1 and target.status == Status.WORD_END_3:
                target.status = Status.WORD_MIDDLE_2
            elif node.status == Status.WORD_END_3:
                if target.status == Status.NOT_WORD_1:
                    target.status = Status.WORD_MIDDLE_2
                if target.get_value() is None:
                    add = True
                target.set_value(node.get_value())
        return add

    def __sizeof__(self):
        return self._size

    def get_key(self):
        return 0

    def get_child(self, key):
        if isinstance(key, str):
            key = char_hash(key)
        return self._children[key]

在强调一下,代码是参考Java实现的,所以有Java开发的风格。在Java中base_node是一个抽象类,然后会封装很多有用的通用方法。然后还设计node这个类。这本身的设计利用多态的特性有很好的效果。

代码里面有两个地方讲一下。

  • node类和bin_trie类都会有get_child方法,对应的就是首字散列和其余二分的体现。
  • 散列函数:散列函数里写了两个类型的散列函数,一个是直接调用Java的函数,一个是利用Python的散列函数,然后取余的做法,这种可能不太准确啊,只是一个简单的想法。

效果测试:

还是按照书上面的测试代码,使用咱们自己实现的bin_trie来加载词典,然后使用前面解释的切分算法,看看效果。但是结果可能和咱们想的不一样。

先上测试代码:

import time
from bin_trie import BinTrie
from nlp.ch02.forward_segment import forward_segment
from nlp.ch02.backward_segment import backward_segment
from nlp.ch02.bidirectional_segment import bidirectional_segment

from pyhanlp import *

def load_dictionary():
    """
    加载Hanlp中的mini词库
    :return:
    """
    IOUtil = JClass('com.hankcs.hanlp.corpus.io.IOUtil')
    path = HanLP.Config.CoreDictionaryPath.replace('.txt', '.mini.txt')
    print(path)
    dic = IOUtil.loadDictionary([path])
    return dic


def evaluate_speed(segment, text, dic, pressure):
    start = time.time()
    for i in range(pressure):
        segment(text, dic)
    elapsed_time = time.time()-start
    print("%s :%.2f 万字/秒" %(segment.__name__, len(text)*pressure/10000/elapsed_time))

def tranfer2dict(dic):
    res = {}
    for key in dic:
        value = dic.get(key)
        res[key] = value
    return res

def test_list(list, pressure):
    start = time.time()
    for i in range(pressure):
        if i in list:
            pass
    print(f"list search cost: {time.time()-start}")

def test_set(set_dic, pressure):
    start = time.time()
    for i in range(pressure):
        if i in set_dic:
            pass
    print(f"set search cost: {time.time()-start}")


if __name__ == '__main__':
    text = "江西鄱阳湖干枯,中国最大淡水湖变成大草原"
    pressure = 10000
    dic = load_dictionary()
    list_dic = [i for i in range(60000)]
    set_dic = set(list_dic)
    test_list(list_dic, pressure)
    test_set(set_dic,pressure)
    start = time.time()
    dict_dic = tranfer2dict(dic)
    bin_tree_dic = BinTrie(dict_dic)
    print(f"build tree cost: {time.time()-start}")
    print(forward_segment(text, bin_tree_dic))

    # evaluate_speed(forward_segment, text, bin_tree_dic, pressure)
    # evaluate_speed(backward_segment, text, bin_tree_dic, pressure)
    # evaluate_speed(bidirectional_segment, text, bin_tree_dic, pressure)

测试结果:

forward_segment :4.34 万字/秒
backward_segment :4.28 万字/秒
bidirectional_segment :2.21 万字/秒

是不是和咱们想的不一样,怎么使用自己设计的数据结构更慢了呢?书中不是说速度快一些吗?其实吧,这个速度是我把hash函数改成使用Python的hash函数,而不是调用Java的来的结果,其实按照调用Java的hash来说,速度还会更慢,大概是这个得四分之一。那么原因在哪呢?

Python变慢的原因:

其实这里得好好看一下对于这个切分算法来说,它核心的是不断查询切分的字符串是否在词典中。在切分算法那一节,load_dictionary函数返回的是set(dic.keySet()),也就是说返回的是set集合。

那么算法的快慢其实就是set集合和字典树查询的快慢问题了。对于Python来说set集合的查询速度是很快的,我的测试代码中其实有简单测试对比set和list查询的速度,结果如下:

list search cost: 0.605888843536377
set search cost: 0.0004830360412597656

大概速度的差了1200倍左右,当然这个测试可能不准确啊,那么为啥set会这么快啊?

set本身的底层实现是hash表实现的,理论上查找速度是O(1)的。dict是用来存储键值对结构的数据的,set其实也是存储的键值对,只是默认键和值是相同的。Python中的dict和set都是通过hash表和散列表来实现的。具体实现其实还有很多细节,特别是对于dict,这里就不多说了,后面有时间也写一下。(这里有一篇别的大佬写的,但是感觉还是没太底层,不过用来理解这个问题够用了。https://blog.csdn.net/siyue0211/article/details/80560783)

那么对于咱们自己的这个bin_tree查找速度如何呢?

前面散列查找O(1),后面是二分查找,速度是O(logN)的。所以就Python来说其实速度会比直接使用set慢。

那为啥Java会快了?

Java字典树使用的是TreeMap,使用的数据结构是红黑树,时间复杂度是O(logN),然后首字散列的二分查找树使用了散列表和二分查找,散列表的时间复杂度是O(1),后面是二分查找,速度是O(logN)的。所以其实bin_tree会快一些。

那为啥Java会比Python的快呢?

那如果说Python set时间复杂度低,那是不是Python实现的结果就应该比Java的快呢?显然不是,因为咱们Python实际上是C++实现的(说的CPython),其实还是会调用C++底层的数据结构,需要解释器解释执行。可能这么解释也不太准确,大概意思就是咱们Python语言执行确实会比Java慢。

3、前缀树的使用

前缀树简单点来说就是在字典树中,从根节点能访问到叶子节点A的话,那就能推论出根节点肯定能访问到节点A的父节点。我们一般说前缀很多都是说字符串,例如单词“自然语言”,那么对应的前缀“自然”、“自然语”都必须在树中存在。

其实从算法的角度考虑,其实前缀树就是会增加了对树遍历的判断条件,不再是遍历后面全部的元素。

  • 前缀全切分:全切分有个地方需要注意就是每次其实是遍历一个字符,不管后面组成的词包含多少个字符。
  • 前缀正向最大切分:判断是否是单词的边界条件是当前节点的value值是否存在,存在就是一个词,然后如何判断是最大的呢?那就是一直按照树遍历,然后每次value都有值的时候,就把他保存下来,这样value这个变量就一直是最长的那个词。这里要先判断节点的value是否存在,再赋值,不能直接赋值,避免出现value为None的情况。例如对于“最大”这个词,其实正向切分会切成【“最”,“大”】,但是词典中有“最大值”这个词。所以“大”这个节点的value是为None的。还有注意下一次遍历是从上一次切分的最大词后一位开始的,所以需要记录上一次词结束的位置。

代码实现:

是在上面前缀树的代码上实现的,修改了base_node.py和bin_tree.py文件

# base_node.py 添加下方法

    def transition(self, text, begin=0):
        cur = self
        for i in range(begin, len(text)):
            cur = cur.get_child(text[i])
            if cur is None or cur.status == Status.UNDEFINED_0:
                return None
        return cur
# bin_tree.py 添加如下方法:

    def parse_text(self, text):
        """
        前缀全切分
        :param text:
        :return:
        """
        length = len(text)
        begin = 0
        state = self
        res = []
        i = 0

        while i < length:
            state = state.transition(text[i])
            if state is not None:
                value = state.get_value()
                if value is not None:
                    res.append(text[begin:i + 1])
            else:
                i = begin
                begin += 1
                state = self
            i += 1
        return res

    def parse_longest_text(self, text):
        """
        前缀正向最大匹配
        :param text:
        :return:
        """
        length = len(text)
        res = []
        i = 0

        while i < length:
            state = self.transition(text[i])
            if state is not None:
                value = state.get_value()
                # end指针指向的是单词的结尾索引,不包含最后那个字
                to = i + 1
                end = to
                while to < length:
                    state = state.transition(text[to])
                    if state is None:
                        break
                    if state.get_value() is not None:
                        value = state.get_value()
                        end = to + 1
                    to += 1
                if value is not None:
                    res.append(text[i:end])
                    i = end - 1

            i += 1
        return res

也是参考Java代码实现的,由于Java和Python的for循环还是有点区别的,还是改了一段时间才实现的,还是比较菜啊。简单测试了几个例子和之前的切分算法结果是一致的,代码应该没啥问题。

测试效果:

fully_segment :2.53 万字/秒
forward_segment :4.41 万字/秒
parse_text :18.99 万字/秒
parse_longest_text :30.65 万字/秒

如上所示,前缀树切分是后面两条记录,可以看到差不多速度提升了7倍多,还是很可观的。

4、总结

本次是针对字典树那一章节的学习,前后花的时间还是很长的,主要是对何晗大佬算法设计的理解。主要涉及是字典树、首字散列其余二分以及前缀树三个方面,花了很多时间查看大佬的Java代码,确实对Java没那么熟悉,阅读理解的比较慢。但是花时间来实现对应的Python代码还是很有必要的,对这些数据结构有了更多的理解。OK,明天开始双数组字典树的学习。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值