什么是DoubleArrayTrie
双数组字典树(Double-Array Trie)是一种高效的字符串检索数据结构,它结合了Trie树的快速查找特性和数组的紧凑存储优势。这种数据结构主要用于实现高速字符串匹配,特别适合处理大规模字典和自然语言处理任务。
双数组字典树的核心思想是使用两个数组来表示Trie树:
- base数组: 用于存储每个节点的基础地址。
- check数组: 用于验证转移是否有效。
双数组字典树的主要特点包括:
- 空间效率: 相比传统Trie树,它大大减少了内存使用。
- 查询速度快: 利用数组的随机访问特性,实现了O(m)的查询复杂度,其中m为待查询字符串的长度。
- 构建复杂: 虽然查询效率高,但构建过程较为复杂,需要解决冲突和优化数组空间。
- 适用于静态词典: 一旦构建完成,不易进行动态更新。
- 广泛应用: 在自然语言处理、信息检索、拼写检查等领域有广泛应用。
基本概念解释
状态(State):
在DoubleArrayTrie中,状态代表Trie树中的一个节点。
每个状态通常用一个整数表示,这个整数是base和check数组的索引。
状态0通常表示根节点。
base数组:
base[s]存储的是状态s的"基础值"。
它用于计算下一个状态的位置。
base值的设计是DoubleArrayTrie构建过程中最关键的部分。
check数组:
check[t]存储的是状态t的父状态。
它用于验证状态转移的正确性。
如果check[t] == s,则说明从状态s到状态t的转移是有效的。
状态转移过程:
假设当前状态为s,要转移的字符为c(通常用字符的ASCII码或者预定义的编码值):
计算下一个可能的状态:t = base[s] + c
验证转移的正确性:检查check[t] == s是否成立
如果验证通过,新的状态就是t;否则,转移失败
通过以上描述,我们发现,双数组字典树物理存储和逻辑上和普通Trie都不太一样,物理存储上,它是两个int数组;逻辑上,普通Trie的节点上会存储字符Value,以及指向子树(或者叫子节点)的指针,而双数组字典数的字符,则是节点和子节点的边(从DFA的角度讲,它是转移条件)。
DoubleArrayTrie中的check数组可以理解为“子节点”指向“父节点”的指针;base数组则和转移条件(即当前字符)一起,用于从前到后计算下一个节点的位置;所以,对于双数组字典数来说,base和check数组的索引以及对应的值都是很有用的信息。
几个重要过程的理解
字符集的映射
可以用ASCII码、UTF-8的码,但这样数据空间要求有点大,如果你的字典字符空间不大,可以自定义字符集的映射。
DAT的构建
先理解原理,再来谈构建。如上文和上图所说,DAT其实是“状态”(base数组)通过转移条件(当前字符)转移到下一个状态的过程;转移后,下一个状态的check数组记录上一个状态的位置(所以上文中称其为“指针”)。构建的过程,就是实践这个的过程,只不过需要工程化处理很多问题,例如数组的扩容、base数组值的确定(base数组value的冲突解决办法)等。
如何查找
对于精确匹配,按照t = base[s] + c,检查check[t] == s是否成立一直往下走就行;对于前缀匹配,则先找到前缀串,再遍历子树。
如何存储字符串对应的Value(即K-V格式数据存储)
可以在叶子节点上存储Value的索引,这样,不管你外部是存List还是Array,通过索引直接访问即可。
示例代码
可参考此代码 https://github.com/komiya-atsushi/darts-java/blob/master/src/main/java/darts/DoubleArrayTrie.java
但你会发现和我们上面说的“貌似”稍稍不太一样,有兴趣的朋友可以自己debug代码试一试。这里给出一个我自己实现的demo,没有resize的过程,查找base数组value的过程也还需要改进,但是可以帮助大家更好的理解DAT的构建和查询过程。
package org.example;
import java.util.*;
/**
* @author JerryX
*/
public class DAT {
//TODO resize
private int[] base;
private int[] check;
private Map<Character, Integer> charCode;
private short depth;
private int currentUnUsedPos;
/**
* DAT构造方法
* @param words 词典列表必须去重,必须按字典序排序
*/
public DAT(List<String> words) {
// default 65536
base = new int[8 * 8 * 1024];
check = new int[8 * 8 * 1024];
Arrays.fill(check, -1);
// root
base[0] = 1;
currentUnUsedPos = 1;
// generate charCode and find depth
extractCharCode(words);
build(words);
}
public int exactMatch(String word) {
int res = -1;
int state = 0;
for (int i = 0; i < word.length(); i++) {
Character c = word.charAt(i);
Integer cCode = charCode.get(c);
if (null == cCode) {
System.out.println("no char " + c + " found in char code");
return res;
}
int nextState = base[state] + cCode;
if (check[nextState] != state) {
return res;
}
state = nextState;
}
int p = base[state];
if (check[p] == state && base[p] < 0 ) {
return -base[p] - 1;
}
return res;
}
public List<String> prefixMatch(String prefix) {
List<String> resList = new LinkedList<>();
int state = 0;
for (int i = 0; i < prefix.length(); i++) {
Character c = prefix.charAt(i);
Integer cCode = charCode.get(c);
if (null == cCode) {
System.out.println("no char " + c + " found in char code");
return resList;
}
int nextState = base[state] + cCode;
if (check[nextState] != state) {
return resList;
}
state = nextState;
}
findStrKey(state, prefix, resList);
return resList;
}
private void findStrKey(int state, String prefix, List<String> results) {
int p = base[state];
if (check[p] == state && base[p] < 0) {
results.add(prefix);
}
for (Map.Entry<Character, Integer> e : charCode.entrySet()) {
Integer cCode = e.getValue();
Character c = e.getKey();
int nextState = base[state] + cCode;
if (check[nextState] == state) {
findStrKey(nextState, prefix + c, results);
}
}
}
/**
* DAT构造
* @param words 词典列表必须去重,必须按字典序排序
*/
protected void build(List<String> words) {
int state = 0;
_build(state, 0, words, 0, words.size());
}
private void _build(int state, int depth, List<String> words, int left, int right) {
// LinkedHashMap 可以默认按插入排序
Map<Integer, Map.Entry<Integer, Integer>> subTries = new LinkedHashMap<>();
for (int i = left;i < right;i++) {
String word = words.get(i);
if (word.length() <= depth) {
continue;
}
Character c = word.charAt(depth);
int cCode = charCode.get(c);
int nextState = base[state] + cCode;
if (check[nextState] != state) {
//转移不存在,需要创建
check[nextState] = state;
}
if (!subTries.containsKey(nextState)) {
subTries.put(nextState, new AbstractMap.SimpleEntry<>(i, i+1));
}
subTries.get(nextState).setValue(i+1);
}
for (Map.Entry<Integer, Map.Entry<Integer, Integer>> e : subTries.entrySet()) {
Integer currentStage = e.getKey();
Map.Entry<Integer, Integer> subWordRange = e.getValue();
int nextDepth = depth + 1;
int baseVal = currentUnUsedPos;
//TODO fix this magic number
while (baseVal < 10000) {
boolean found = true;
for (int i=subWordRange.getKey();i<subWordRange.getValue();i++) {
String word = words.get(i);
if (word.length() < nextDepth) {
continue;
}
int cCode;
if (word.length() > nextDepth) {
Character c = word.charAt(nextDepth);
cCode = charCode.get(c);
} else {
cCode = 0;
}
int nextState = baseVal + cCode;
if (check[nextState] != -1) {
found = false;
break;
}
}
if (found) {
base[currentStage] = baseVal;
break;
}
baseVal++;
}
currentUnUsedPos = baseVal + 1;
int cnt = 0;
for (int i=subWordRange.getKey();i<subWordRange.getValue();i++) {
String word = words.get(i);
if (word.length() == nextDepth) {
int endStrPos = base[currentStage];
check[endStrPos] = currentStage;
base[endStrPos] = -i - 1;
cnt += 1;
}
}
if (cnt < (subWordRange.getValue() - subWordRange.getKey())) {
_build(currentStage, nextDepth, words, subWordRange.getKey(), subWordRange.getValue());
}
}
}
private void extractCharCode(List<String> words) {
Set<Character> charSet = new LinkedHashSet<>();
for (String word : words) {
for (int i = 0; i < word.length(); i++) {
charSet.add(word.charAt(i));
}
if (word.length() > depth) {
depth = (short) word.length();
}
}
charCode = new HashMap<>((int) (charSet.size() / 0.75f) + 1, 1);
int idx = 1;
for (Character c : charSet) {
charCode.put(c, idx++);
}
}
public static void main(String[] args) {
List<String> words = new ArrayList<>();
words.add("中国");
words.add("中国人");
words.add("中华");
words.add("中华人民共和国");
//保证按字典序有序
words.sort(String::compareTo);
System.out.println(words);
DAT dat = new DAT(words);
System.out.println(dat.exactMatch("中"));
System.out.println(dat.exactMatch("中国"));
System.out.println(dat.exactMatch("中国人"));
System.out.println(dat.exactMatch("中华"));
System.out.println(dat.exactMatch("中华人民"));
System.out.println(dat.exactMatch("中华人民共和国"));
System.out.println(dat.prefixMatch("中国"));
System.out.println(dat.prefixMatch("中华"));
System.out.println(dat.prefixMatch("中"));
}
}