insert 数组_DoubleArrayTrie(DAT)双数组字典树原理解读,golang/go语言实现

首先说明这不是一篇针对初学者的文章,如果你想了解字典树的概念请自行百度或者参考hanlp作者的文章,这里只关注一个那就是,双数组字典树,以下简称DAT。

最初接触到是CRF模型中有这样一个需求,那时候看了找了几篇文章看了还是懵懵懂懂,试着写了下,效率堪忧,最后还是用了darts实现,最近因为需要用到,于是想搞懂,就仔细看了darts的实现。有时候搜别人博客看来看去还不是如暴力点直接看代码来的清晰。。,下面就基于自己参考darts的实现来对原理做一个解释。

参考代码darts

1. DAT的构建

有以下词库:

0:一帆风顺
1:一流
2:了不起
3:了解
4:小心
5:小心谨慎

注意是排好顺序的,这也是DAT的适用场景(静态词库,查询性能要求高),我们期望得到如下DAT树:

              root
        /             
      一       了       小
     /       /        |
   帆   流    不  解     心
   ..............................【省略】

在DAT的实现,我们通过一个数组来保存这颗树,你可能想到,从root往下,每进行一次状态转移就跳到另一个节点,你可以理解为用数组存储二叉树,只不过二叉树中可以直接用索引计算得到左右节点,而这里比较复杂。 初始化两个数组,一个base,一个check,base用于存储offset值,也可以认为是一个状态值,而check数组仅仅用于保证状态转移过程中的正确性。

现在我们需要开始构建DAT: 初始数组中都为0,表示未被占用,DAT的构建顺序为DFS(深度优先),从p节点开始,(1) p初始为root,作为公共前缀,找到它的孩子节点为“一”,“了”,“小” (2)在数组中找到一个offset,使得:
check[offset+code(“一”)]= 0
check[offset+code(“了”)]= 0
check[offset+code(“小”)]= 0
也就是这个3个空间都未被占用:code("")就是基于字符的编码做简单的运算,比如DAT中就是编码+1
找到offset后,设置:
check[offset+code(“一”)]= offset
check[offset+code(“了”)]= offset
check[offset+code(“小”)]= offset
设置这个值,它用于搜索时的验证。如何找到这个offset,实际就是枚举尝试,不过DAT的实现中并不是总是从0开始枚举,它采用启发式的方法可以避免不必要的尝试,后面会讲到。(3)当找到offset后,就找到当前前缀转移到子节点的状态值,子节点转移到其他节点状态值何来?这就需要继续DFS下去,我们循环遍历 “一”,“了”,“小”:
设置(1)中的p=遍历到的节点
回到(1)(2),获取遍历到的节点转移到其子节点的offset
然后设置base值。例如:
当前遍历的是“一”,经历步骤(1)(2)后,拿到offset2
那么我们设置base[offset1 + code("一")] = offset2offset1是root节点转移到“一”的值,而offset2,是“一”转移到其子节点的值

这就是构建DAT的大致框架,也是核心代码,更详细的理解建议阅读代码,
已经详细注释,如果看不懂go,可以看java。代码中有一些细节问题,比如为什么offset要从1开始?为什么我们遇到叶子节点要base值存储的是-left-1?(这就是当前词在词库中的索引,关于left请细读源码)DAT中的启发式的体现?步骤(1)其实就是代码中的fetch

/*
对于输入的词语:
a
ab
abc
dd
dgh
hello
列的方向来看,同一列的字符处于trie树的同一层
找到parent的孩子节点,并界定它们的孩子节点搜索范围
eg:
对于
ab
abc
be
bfg
c
若当前parent = 'root'
那么结果[]*node将产生:
[['code': a,'d': 1,'l': 0,'r': 2] ['code': b,'d': 1,'l': 2,'r': 4] ['code': c,'d': 1,'l': 4,'r': 5]]
注:d:depth,l:left,r:right
参数:
    parent 父节点
返回:
    []*node: 孩子节点
*/
func (d *DoubleArrayTrie) fetch(parent *node) ([]*node, error);

步骤(2)就是代码中的insert

/*
 核心:
    对于输入的children,
    (1)若children为空,退出
    (2)否则,找到一个合适的offset,使得check[offset+c1] = check[offset+c2] = ...=0
    (3)DFS的继续构建树,依次对子节点构建c1,c2,c3,调用fetch,调用insert回到(1)
返回:
    int:当前找到的可用的offset
*/
func (d *DoubleArrayTrie) insert(children []*node) (int, error)

注释中对DAT中的状态转移描述

DAT中的有限状态自动机(DFA):
1.前置说明:
   code:
    字符编码
   trie:
    code构成的前缀树
   状态:
    存在一个offset,令其为状态s,它是一个数组索引的偏移量
2.状态转移base:
    给定一个单词"c[1]c[2]c[3]...c[n]",c[i]表示单词中第i个字符的编码
    转移方程为:
        base[s_[i] + c[i]] = s_[i+1]
        i = 0,s_[0] = 1,表示根节点的状态
        注意:s_[i+1]仅仅表示输入单词中c[i]转移到的特征,而不对应有限状态集合Q中的Q[i+1]
        有以下:
            s_[0] = 1
            base[0 + c[1]] = s_[1]
            base[s_[1]+c[2]] = s_[3]
            ...
        这个base就是用于存储状态转移的数组,它接受输入s_[i]和c[i],输出下一个状态s_[i+1]
3.check数组:
    有以下词库【有序】:
    a
    ab
    abc
    ad
    ba
    转换为dat树:
          root
         /    
        a     b
       /    /
      b  d  a
     /
    c
    容易知道,词库纵向来看,每一列就是树的一层,互为兄弟姐弟
    对于互为兄弟节点的c1,c2,c3,c4....cn (意味着它们有共同的前缀)
    若在check数组中,有check[s+c1] = check[s+c2] =check[s+c3] =...=0(为0表示该处位置未被占用)
       那么令check[s+c1] = check[s+c2] = check[s+c3] =...=s
    check相当于保存了c1,c2,c3...共同的前缀状态s
3.无论是check,还是base,它们都接受一个状态s1和一个code,从而产生一个新的状态s2,产生使用了索引,这也是高效的原因

2. DAT的搜索

如果对构建理解了,那么就搜索就没什么好说的,实际就是输入一个词,转换为字符数组,[a,b,c,d]后,利用字符编码和base数组进行状态转移的过程,当搜索到叶子节点就可以通过base值拿到词的索引,还记得之前说叶子节点base值赋值为-left-1吗?这就是它的作用体现

3.附上完整代码和测试

GitHub地址 完整版也附上,方便懒的同学...

package dat

import (
    "encoding/gob"
    "errors"
    "fmt"
    "log"
    "os"
    "reflect"
    "sort"
)

/*
算法何其复杂
DAT中的有限状态自动机(DFA):
1.前置说明:
   code:
    字符编码
   trie:
    code构成的前缀树
   状态:
    存在一个offset,令其为状态s,它是一个数组索引的偏移量
2.状态转移base:
    给定一个单词"c[1]c[2]c[3]...c[n]",c[i]表示单词中第i个字符的编码
    转移方程为:
        base[s_[i] + c[i]] = s_[i+1]
        i = 0,s_[0] = 1,表示根节点的状态
        注意:s_[i+1]仅仅表示输入单词中c[i]转移到的特征,而不对应有限状态集合Q中的Q[i+1]
        有以下:
            s_[0] = 1
            base[0 + c[1]] = s_[1]
            base[s_[1]+c[2]] = s_[3]
            ...
        这个base就是用于存储状态转移的数组,它接受输入s_[i]和c[i],输出下一个状态s_[i+1]
3.check数组:
    有以下词库【有序】:
    a
    ab
    abc
    ad
    ba
    转换为dat树:
          root
         /    
        a     b
       /    /
      b  d  a
     /
    c
    容易知道,词库纵向来看,每一列就是树的一层,互为兄弟姐弟
    对于互为兄弟节点的c1,c2,c3,c4....cn (意味着它们有共同的前缀)
    若在check数组中,有check[s+c1] = check[s+c2] =check[s+c3] =...=0(为0表示该处位置未被占用)
       那么令check[s+c1] = check[s+c2] = check[s+c3] =...=s
    check相当于保存了c1,c2,c3...共同的前缀状态s
3.无论是check,还是base,它们都接受一个状态s1和一个code,从而产生一个新的状态s2,产生使用了索引,这也是高效的原因
*/

// 构建trie树使用的节点
type node struct {
    //字符编码,这里的编码是unicode + 1,+1是略过root节点,
    // 若不+1,unicode = 0的字符将占用base[0],check[0]
    code  int
    depth int //所处树的层级,正好对应子节点在key中索引
    left  int // 当前字符在key list中搜索的左边界索引 (包括)
    right int // 当前字符在key list中搜索的右边界索引(不包括)
}

func (n *node) String() string {
    d := fmt.Sprint(n.depth)
    l := fmt.Sprint(n.left)
    r := fmt.Sprint(n.right)
    return "['code': " + string(rune(n.code-1)) + ",'d': " + d + ",'l': " + l + ",'r': " + r + "]"
}

const (
    // 初始化base大小
    INIT_SIZE = 65536 * 32
)

type key []rune
type DoubleArrayTrie struct {
    check        []int
    base         []int
    size         int         //对于base,check真正用到的大小
    allocSize    int         // 分配的数组大小
    keys         []key       // key list
    keySize      int         //key的数量
    values       interface{} //k-v中的v
    progress     int         // 构建进度,运行时非前缀key的数量
    nextCheckPos int         //下一次insert可能开始的检查位置
}

// 由于不想对外暴露DoubleArrayTrie的字段,但是gob协议中又需要编码
// 所以被迫这里使用一个中间结构来达到目的
type DATExport struct {
    Check        []int
    Base         []int
    Size         int
    AllocSize    int
    Keys         []key
    KeySize      int
    Values       interface{}
    Progress     int
    NextCheckPos int
}

func NewDoubleArrayTrie() *DoubleArrayTrie {
    return &DoubleArrayTrie{}
}

/*
    对base,used,check扩容
*/
func (d *DoubleArrayTrie) resize(newSize int) int {
    base2 := make([]int, newSize, newSize)
    check2 := make([]int, newSize, newSize)
    if d.allocSize > 0 {
        copy(base2, d.base)
        copy(check2, d.check)
    }
    d.base = base2
    d.check = check2
    d.allocSize = newSize
    return newSize
}

// 获取key的数量
func (d *DoubleArrayTrie) GetKeySize() int {
    return d.keySize
}

/*
对于输入的词语:
a
ab
abc
dd
dgh
hello
列的方向来看,同一列的字符处于trie树的同一层
找到parent的孩子节点,并界定它们的孩子节点搜索范围
eg:
对于
ab
abc
be
bfg
c
若当前parent = 'root'
那么结果[]*node将产生:
[['code': a,'d': 1,'l': 0,'r': 2] ['code': b,'d': 1,'l': 2,'r': 4] ['code': c,'d': 1,'l': 4,'r': 5]]
注:d:depth,l:left,r:right
参数:
    parent 父节点
返回:
    []*node: 孩子节点
*/
func (d *DoubleArrayTrie) fetch(parent *node) ([]*node, error) {
    // 搜索范围left->right
    // 搜索层:parent.depth
    var pre rune
    children := make([]*node, 0)
    for i := parent.left; i < parent.right; i++ {
        keyLen := len(d.keys[i])
        //"a,ab",parent='a'这种情况需要跳过 'a'
        if keyLen < parent.depth {
            continue
        }
        var cur rune = 0
        if keyLen != parent.depth {
            cur = d.keys[i][parent.depth] + 1
        }
        // 非字典序
        if pre > cur {
            msg := fmt.Sprintf("keys are not dict order, pre=%c, cur=%c,key=%s", pre-1, cur-1, string(d.keys[i]))
            return children, errors.New(msg)
        }
        // 遇到公共前缀依加一个NULL节点,保证搜索时能搜索到
        if cur == pre && len(children) > 0 {
            continue
        }
        newNode := new(node)
        newNode.left = i
        newNode.depth = parent.depth + 1
        newNode.code = int(cur)
        // 扫描到和上一个字符不重复,更新上一个字符的右边界
        if len(children) > 0 {
            children[len(children)-1].right = i
        }
        children = append(children, newNode)
        pre = cur
    }
    if len(children) > 0 {
        children[len(children)-1].right = parent.right
    }
    return children, nil
}

/*
 核心:
    对于输入的children,
    (1)若children为空,退出
    (2)否则,找到一个合适的offset,使得check[offset+c1] = check[offset+c2] = ...=0
    (3)DFS的继续构建树,依次对子节点构建c1,c2,c3,调用fetch,调用insert回到(1)
返回:
    int:当前找到的可用的offset
*/
func (d *DoubleArrayTrie) insert(children []*node) (int, error) {
    pos := 0
    //  启发式方法可以避免每次begin从0开始检测
    if d.nextCheckPos > children[0].code {
        pos = d.nextCheckPos - 1
    } else {
        pos = children[0].code
    }
    begin := 0 // 偏移量,>=1,base[0] = 1,第一个begin必定为1
    firstNonZero := true
    nonZeroNum := 0 // 非0的pos计数
outer:
    for {
        pos++
        begin = pos - children[0].code
        // 被占用
        if d.check[pos] != 0 {
            nonZeroNum++
            continue
        } else if firstNonZero {
            d.nextCheckPos = pos
            firstNonZero = false
        }
        // 扩容,最大长度,即当前偏移量+最大编码值+1已经大于等于了分配容量
        if s := begin + children[len(children)-1].code + 1; s > d.allocSize {
            rate := 0.0
            pr := float64(1.0 * d.keySize / (d.progress + 1))
            if 1.05 > pr {
                rate = 1.05
            } else {
                rate = pr
            }
            d.resize(int(float64(s) * rate))
        }
        for i := 1; i < len(children); i++ {
            if d.check[begin+children[i].code] != 0 {
                // 之前这里写的continue导致了bug
                continue outer
            }
        }
        break
    }
    //fmt.Println("find begin:", begin)
    if s := begin + children[len(children)-1].code + 1; d.size < s {
        d.size = s
    }
    // 简单的启发式方法:如果检查过的位置中95%都是被占用的(nextCheckPos(第一次开始有check=0)~pos(成功找到begin的一次的pos))
    // 那么设置nextCheckPos为检查结束的值,可能避免下一次insert从begin=0开始检测
    if s := pos - d.nextCheckPos + 1; float64(nonZeroNum*1.0/s) >= 0.95 {
        d.nextCheckPos = pos
    }
    for i := 0; i < len(children); i++ {
        d.check[begin+children[i].code] = begin
    }
    // 针对孩子节点继续递归构建
    for _, chi := range children {
        nodes, err := d.fetch(chi)
        if err != nil {
            log.Fatal(err)
            return begin, err
        }
        // 没有孩子节点
        if len(nodes) == 0 {
            // -1 是为了确保base值小于0
            // 当一个key是独立存在的,非前缀,其最后一个字符必是叶子节点,此时left=key的索引
            //通过状态转移拿到的base值可以还原为left,那么就可以索引到key,后面的exactMatch基于此搜索
            d.base[begin+chi.code] = -chi.left - 1
            // 到叶子节点用掉一个key(不包括公共前缀)
            d.progress++
        } else {
            nexState, err := d.insert(nodes)
            if err != nil {
                log.Fatal(err)
                return begin, err
            }
            // 状态转移
            d.base[begin+chi.code] = nexState
        }
    }
    return begin, nil
}

// 只用keys来build
func (d *DoubleArrayTrie) Build1(keys []string) error {
    return d.build_(keys, nil, false)
}

// 用key和value来build
// vals必须传入切片
func (d *DoubleArrayTrie) Build2(keys []string, vals interface{}) error {
    return d.build_(keys, vals, false)
}

// 用key来build,并且对key排序
func (d *DoubleArrayTrie) BuildWithSort(keys []string) error {
    return d.build_(keys, nil, true)
}

func (d *DoubleArrayTrie) build_(keys []string, vals interface{}, needSort bool) error {
    if vals != nil {
        typeOf := reflect.TypeOf(vals)
        if typeOf.Kind() != reflect.Slice {
            log.Fatalln("vals are not slice type")
            return errors.New("vals are not slice type")
        }
        d.values = vals
    }
    if needSort {
        sort.Strings(keys)
    }
    return d.build(keys)
}

// 底层构建方法
func (d *DoubleArrayTrie) build(keys []string) error {
    if len(keys) == 0 {
        log.Fatal("empty keys")
        return errors.New("empty keys")
    }
    keys2 := make([]key, 0, len(keys))
    for _, str := range keys {
        keys2 = append(keys2, []rune(str))
    }
    d.keys = keys2
    d.keySize = len(keys)
    // base,check默认初始大小
    d.resize(INIT_SIZE)
    d.nextCheckPos = 0
    root := new(node)
    root.left = 0
    root.right = d.keySize
    root.depth = 0
    children, err := d.fetch(root)
    if err != nil {
        return err
    }
    begin, err := d.insert(children)
    if err != nil {
        return err
    }
    log.Println("first begin = ", begin)
    d.base[0] = begin // 应该为1
    log.Println("build done...")
    log.Println("DAT:", d)
    // 压缩数组
    d.shrink()
    return nil
}

/*
 压缩base,check数组
 长度压缩为真实的size
*/
func (d *DoubleArrayTrie) shrink() {
    d.resize(d.size)
}

/*
    根据key返回value
    ...没有范型的无奈
*/
func (d *DoubleArrayTrie) GetValue(key string) interface{} {
    if d.values == nil {
        return nil
    }
    index, ok := d.IndexOf(key)
    if ok {
        valueOf := reflect.ValueOf(d.values)
        return valueOf.Index(index).Interface()
    }
    return nil
}

//返回一个key在数组的索引
//返回
// int: key在slice中的索引
// bool: 是否正常拿到
func (d *DoubleArrayTrie) IndexOf(key string) (int, bool) {
    return d.ExactMatchSearch(key)
}

//在DAT中搜索给定key,
//搜索成功:返回key所在的index
//为什么能拿到key的index?
//  当到达叶子节点时,它的base值赋值为了-left-1,这个left实际上就是当前key的index,
//  所以搜索时能够直接拿到index值
// 返回
//  int: 索引值
//      -1:出现key错误或者转移过程不满足check
//      -2:当前key作为了公共前缀,无法得到index
//      >=0: 当前key正常搜索到
// 为什么darts原代码多一次状态转移?【v1.1已解决】
//  对于根节点调用fetch,darts中因为pre==cur并不会跳过,所以依然产生一个节点放到children中,所以导致有一次多的状态转移
//   而我的实现,对于pre==cur是跳过的,不会产生NULL节点放到children中,所以不会多一次转移
//  【v1.2修复】不多转移一次会导致搜索不到公共前缀
func (d *DoubleArrayTrie) ExactMatchSearch(key string) (res int, ok bool) {
    if key == "" {
        return -1, false
    }
    chs := []rune(key)
    kLen := len(chs)
    begin := d.base[0]
    // root + code(a) -> s1, check[root + code[a]] = root
    // s1 + code[b] -> s2, check[s1 + code[b]] = s1
    for i := 0; i < kLen; i++ {
        // 状态转移函数的输入
        index := begin + int(chs[i]+1)
        if d.check[index] != begin {
            log.Fatalf("error transition, begin = %v, check[index]=%v,code=%cn", begin, d.check[index], chs[i])
            return -1, false
        }
        // 转移到下一个状态
        begin = d.base[index]
    }
    // NULL节点 code 为0
    index := begin + 0
    if d.check[index] != begin {
        log.Fatalf("can't trasfer to NULL node")
        return -1, false
    }
    begin = d.base[index]
    // 再转移一次
    if begin < 0 {
        // begin = -left -1
        // left = -begin -1
        return -begin - 1, true
    }
    return -2, false
}

func (d *DoubleArrayTrie) String() string {
    size := fmt.Sprint(d.size)
    alSize := fmt.Sprint(d.allocSize)
    return `
    [
        size : ` + size + `,
        allocSize : ` + alSize + `,
        keySize: ` + fmt.Sprint(d.keySize) + `,
        progress: ` + fmt.Sprint(d.progress) + `,   
    ]
           `
}

/*
原代码中就是搜索key锁包含的所有可能公共前缀
这里不实现
*/
func (d *DoubleArrayTrie) CommonPrefixSearch(key string) []string {
    subs := make([]string, 1)
    return subs
}

// 保存build好的DAT到指定路径
// 使用gob协议
func (d *DoubleArrayTrie) Store(path string) error {
    file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, os.ModePerm)
    if err != nil {
        log.Fatalln(err)
        return err
    }
    defer file.Close()
    encoder := gob.NewEncoder(file)
    dat := new(DATExport)

    dat.AllocSize = d.allocSize
    dat.Base = d.base
    dat.Check = d.check
    dat.Keys = d.keys
    dat.NextCheckPos = d.nextCheckPos
    dat.Progress = d.progress
    dat.Size = d.size
    dat.Values = d.values
    dat.KeySize = d.keySize

    err = encoder.Encode(dat)
    if err != nil {
        log.Println(err)
        return err
    }
    return nil
}

// 从指定路径加载DAT
func (d *DoubleArrayTrie) Load(path string) error {
    file, err := os.Open(path)
    if err != nil {
        log.Fatalln(err)
        return err
    }
    defer file.Close()
    decoder := gob.NewDecoder(file)
    dat := new(DATExport)
    err = decoder.Decode(dat)
    if err != nil {
        log.Fatalln(err)
        return err
    }

    d.allocSize = dat.AllocSize
    d.base = dat.Base
    d.check = dat.Check
    d.keys = dat.Keys
    d.nextCheckPos = dat.NextCheckPos
    d.progress = dat.Progress
    d.size = dat.Size
    d.values = dat.Values
    d.keySize = dat.KeySize

    return nil
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值