双数组字典树(DoubleArrayTrie)

什么是DoubleArrayTrie

双数组字典树(Double-Array Trie)是一种高效的字符串检索数据结构,它结合了Trie树的快速查找特性和数组的紧凑存储优势。这种数据结构主要用于实现高速字符串匹配,特别适合处理大规模字典和自然语言处理任务。

双数组字典树的核心思想是使用两个数组来表示Trie树:

  1. base数组: 用于存储每个节点的基础地址。
  2. check数组: 用于验证转移是否有效。

双数组字典树的主要特点包括:

  1. 空间效率: 相比传统Trie树,它大大减少了内存使用。
  2. 查询速度快: 利用数组的随机访问特性,实现了O(m)的查询复杂度,其中m为待查询字符串的长度。
  3. 构建复杂: 虽然查询效率高,但构建过程较为复杂,需要解决冲突和优化数组空间。
  4. 适用于静态词典: 一旦构建完成,不易进行动态更新。
  5. 广泛应用: 在自然语言处理、信息检索、拼写检查等领域有广泛应用。

基本概念解释

状态(State):

在DoubleArrayTrie中,状态代表Trie树中的一个节点。
每个状态通常用一个整数表示,这个整数是base和check数组的索引。
状态0通常表示根节点。


base数组:

base[s]存储的是状态s的"基础值"。
它用于计算下一个状态的位置。
base值的设计是DoubleArrayTrie构建过程中最关键的部分。


check数组:

check[t]存储的是状态t的父状态。
它用于验证状态转移的正确性。
如果check[t] == s,则说明从状态s到状态t的转移是有效的。

状态转移过程
假设当前状态为s,要转移的字符为c(通常用字符的ASCII码或者预定义的编码值):

计算下一个可能的状态:t = base[s] + c
验证转移的正确性:检查check[t] == s是否成立
如果验证通过,新的状态就是t;否则,转移失败

通过以上描述,我们发现,双数组字典树物理存储和逻辑上和普通Trie都不太一样,物理存储上,它是两个int数组;逻辑上,普通Trie的节点上会存储字符Value,以及指向子树(或者叫子节点)的指针,而双数组字典数的字符,则是节点和子节点的边(从DFA的角度讲,它是转移条件)。

 DoubleArrayTrie中的check数组可以理解为“子节点”指向“父节点”的指针;base数组则和转移条件(即当前字符)一起,用于从前到后计算下一个节点的位置;所以,对于双数组字典数来说,base和check数组的索引以及对应的值都是很有用的信息

几个重要过程的理解

字符集的映射

        可以用ASCII码、UTF-8的码,但这样数据空间要求有点大,如果你的字典字符空间不大,可以自定义字符集的映射。

DAT的构建

先理解原理,再来谈构建。如上文和上图所说,DAT其实是“状态”(base数组)通过转移条件(当前字符)转移到下一个状态的过程;转移后,下一个状态的check数组记录上一个状态的位置(所以上文中称其为“指针”)。构建的过程,就是实践这个的过程,只不过需要工程化处理很多问题,例如数组的扩容、base数组值的确定(base数组value的冲突解决办法)等。

如何查找

对于精确匹配,按照t = base[s] + c,检查check[t] == s是否成立一直往下走就行;对于前缀匹配,则先找到前缀串,再遍历子树。

如何存储字符串对应的Value(即K-V格式数据存储)

可以在叶子节点上存储Value的索引,这样,不管你外部是存List还是Array,通过索引直接访问即可。

示例代码

可参考此代码 https://github.com/komiya-atsushi/darts-java/blob/master/src/main/java/darts/DoubleArrayTrie.java

但你会发现和我们上面说的“貌似”稍稍不太一样,有兴趣的朋友可以自己debug代码试一试。这里给出一个我自己实现的demo,没有resize的过程,查找base数组value的过程也还需要改进,但是可以帮助大家更好的理解DAT的构建和查询过程。

package org.example;

import java.util.*;

/**
 * @author JerryX
 */
public class DAT {

    //TODO resize

    private int[] base;

    private int[] check;

    private Map<Character, Integer> charCode;

    private short depth;

    private int currentUnUsedPos;

    /**
     * DAT构造方法
     * @param words 词典列表必须去重,必须按字典序排序
     */
    public DAT(List<String> words) {
        // default 65536
        base = new int[8 * 8 * 1024];
        check = new int[8 * 8 * 1024];
        Arrays.fill(check, -1);
        // root
        base[0] = 1;
        currentUnUsedPos = 1;
        // generate charCode and find depth
        extractCharCode(words);

        build(words);
    }

    public int exactMatch(String word) {
        int res = -1;
        int state = 0;
        for (int i = 0; i < word.length(); i++) {
            Character c = word.charAt(i);
            Integer cCode = charCode.get(c);
            if (null == cCode) {
                System.out.println("no char " + c + " found in char code");
                return res;
            }
            int nextState = base[state] + cCode;
            if (check[nextState] != state) {
                return res;
            }
            state = nextState;
        }
        int p = base[state];
        if (check[p] == state && base[p] < 0 ) {
            return -base[p] - 1;
        }
        return res;
    }

    public List<String> prefixMatch(String prefix) {
        List<String> resList = new LinkedList<>();
        int state = 0;
        for (int i = 0; i < prefix.length(); i++) {
            Character c = prefix.charAt(i);
            Integer cCode = charCode.get(c);
            if (null == cCode) {
                System.out.println("no char " + c + " found in char code");
                return resList;
            }
            int nextState = base[state] + cCode;
            if (check[nextState] != state) {
                return resList;
            }
            state = nextState;
        }
        findStrKey(state, prefix, resList);
        return resList;
    }

    private void findStrKey(int state, String prefix, List<String> results) {
        int p = base[state];
        if (check[p] == state && base[p] < 0) {
            results.add(prefix);
        }
        for (Map.Entry<Character, Integer> e : charCode.entrySet()) {
            Integer cCode = e.getValue();
            Character c = e.getKey();
            int nextState = base[state] + cCode;
            if (check[nextState] == state) {
                findStrKey(nextState, prefix + c, results);
            }
        }
    }

    /**
     * DAT构造
     * @param words 词典列表必须去重,必须按字典序排序
     */
    protected void build(List<String> words) {
        int state = 0;
        _build(state, 0, words, 0, words.size());
    }

    private void _build(int state, int depth, List<String> words, int left, int right) {
        // LinkedHashMap 可以默认按插入排序
        Map<Integer, Map.Entry<Integer, Integer>> subTries = new LinkedHashMap<>();
        for (int i = left;i < right;i++) {
            String word = words.get(i);
            if (word.length() <= depth) {
                continue;
            }
            Character c = word.charAt(depth);
            int cCode = charCode.get(c);
            int nextState = base[state] + cCode;
            if (check[nextState] != state) {
                //转移不存在,需要创建
                check[nextState] = state;
            }
            if (!subTries.containsKey(nextState)) {
                subTries.put(nextState, new AbstractMap.SimpleEntry<>(i, i+1));
            }
            subTries.get(nextState).setValue(i+1);
        }
        for (Map.Entry<Integer, Map.Entry<Integer, Integer>> e : subTries.entrySet()) {
            Integer currentStage = e.getKey();
            Map.Entry<Integer, Integer> subWordRange = e.getValue();
            int nextDepth = depth + 1;
            int baseVal = currentUnUsedPos;
            //TODO fix this magic number
            while (baseVal < 10000) {
                boolean found = true;
                for (int i=subWordRange.getKey();i<subWordRange.getValue();i++) {
                    String word = words.get(i);
                    if (word.length() < nextDepth) {
                        continue;
                    }
                    int cCode;
                    if (word.length() > nextDepth) {
                        Character c = word.charAt(nextDepth);
                        cCode = charCode.get(c);
                    } else {
                        cCode = 0;
                    }
                    int nextState = baseVal + cCode;
                    if (check[nextState] != -1) {
                        found = false;
                        break;
                    }
                }
                if (found) {
                    base[currentStage] = baseVal;
                    break;
                }
                baseVal++;
            }
            currentUnUsedPos = baseVal + 1;

            int cnt = 0;
            for (int i=subWordRange.getKey();i<subWordRange.getValue();i++) {
                String word = words.get(i);
                if (word.length() == nextDepth) {
                    int endStrPos = base[currentStage];
                    check[endStrPos] = currentStage;
                    base[endStrPos] = -i - 1;
                    cnt += 1;
                }
            }
            if (cnt < (subWordRange.getValue() - subWordRange.getKey())) {
                _build(currentStage, nextDepth, words, subWordRange.getKey(), subWordRange.getValue());
            }
        }
    }

    private void extractCharCode(List<String> words) {
        Set<Character> charSet = new LinkedHashSet<>();
        for (String word : words) {
            for (int i = 0; i < word.length(); i++) {
                charSet.add(word.charAt(i));
            }
            if (word.length() > depth) {
                depth = (short) word.length();
            }
        }
        charCode = new HashMap<>((int) (charSet.size() / 0.75f) + 1, 1);
        int idx = 1;
        for (Character c : charSet) {
            charCode.put(c, idx++);
        }
    }

    public static void main(String[] args) {
        List<String> words = new ArrayList<>();
        words.add("中国");
        words.add("中国人");
        words.add("中华");
        words.add("中华人民共和国");
        //保证按字典序有序
        words.sort(String::compareTo);
        System.out.println(words);
        DAT dat = new DAT(words);

        System.out.println(dat.exactMatch("中"));
        System.out.println(dat.exactMatch("中国"));
        System.out.println(dat.exactMatch("中国人"));
        System.out.println(dat.exactMatch("中华"));
        System.out.println(dat.exactMatch("中华人民"));
        System.out.println(dat.exactMatch("中华人民共和国"));

        System.out.println(dat.prefixMatch("中国"));
        System.out.println(dat.prefixMatch("中华"));
        System.out.println(dat.prefixMatch("中"));
    }

}

  • 12
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值