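Preface
Sometimes we wonder whether a trie (see the earlier article "Trie") could borrow the fail-fast idea of KMP (see the article "KMP String Pattern Matching Algorithm") to speed up word matching. That is exactly what the Aho-Corasick (AC) automaton does.
AC Automaton Algorithm
First, we take the words she, he, hers and his and build a trie from them with the AC automaton algorithm.
The difference from an ordinary trie is the extra dashed pointers, called "fail pointers".
Verifying the Fail Pointers
Let's verify them against the text "ushers".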
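- Input 'u': node 0 has no 'u' child, so the match fails. The fail pointer leads back to node 0, and since that is the root, the search for 'u' stops there.
- Input 's': advance to node 3.
- Input 'h': advance to node 4.
- Input 'e': advance to node 5. Node 5 is a terminal node; collect {"he", "she"}.
- Input 'r': node 5 has no 'r' child, so the match fails. Follow the fail pointer to node 2, find 'r' below node 2 and advance to node 8.
- Input 's': advance to node 9. Node 9 is a terminal node; collect {"hers"}.
- So the words found are {"he", "she", "hers"}.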
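To double-check the walk, here is a minimal Java sketch that hard-codes the trie and its fail pointers with the node numbers used above and replays "ushers". The numbering of node 6 ("hi") and node 7 ("his") is an assumption, since the walk never visits them, and the class name `UshersWalk` is hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class UshersWalk {
    // Edges of the trie for {he, she, his, hers}, numbered as in the walkthrough
    // (node 6 = "hi" and node 7 = "his" are assumed; the walk never visits them)
    static final Map<Integer, Map<Character, Integer>> GOTO = new HashMap<>();
    // FAIL[n] = fail pointer of node n (index 0, the root, is never followed)
    static final int[] FAIL = {0, 0, 0, 0, 1, 2, 0, 3, 0, 3};
    // Words collected on reaching each node (node 5 already includes "he" from node 2)
    static final String[][] OUTPUT = {{}, {}, {"he"}, {}, {}, {"he", "she"}, {}, {"his"}, {}, {"hers"}};

    static void edge(int from, char c, int to) {
        GOTO.computeIfAbsent(from, k -> new HashMap<>()).put(c, to);
    }

    static {
        edge(0, 'h', 1); edge(1, 'e', 2); edge(0, 's', 3); edge(3, 'h', 4); edge(4, 'e', 5);
        edge(1, 'i', 6); edge(6, 's', 7); edge(2, 'r', 8); edge(8, 's', 9);
    }

    // Child of state via c, or -1 if there is no such edge
    static int next(int state, char c) {
        return GOTO.getOrDefault(state, Collections.emptyMap()).getOrDefault(c, -1);
    }

    static List<String> run(String text) {
        List<String> found = new ArrayList<>();
        int state = 0;
        for (char c : text.toCharArray()) {
            // On a mismatch follow fail pointers; at the root there is nowhere else to go
            while (state != 0 && next(state, c) == -1) state = FAIL[state];
            int to = next(state, c);
            state = (to == -1) ? 0 : to;
            Collections.addAll(found, OUTPUT[state]);
        }
        return found;
    }

    public static void main(String[] args) {
        System.out.println(run("ushers")); // [he, she, hers]
    }
}
```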
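The walk above shows what a fail pointer really is: when the next character cannot be matched, jump to the node for the longest proper suffix of the text matched so far that is also a prefix in the trie. For node 5 ("she") that suffix is "he", which is node 2.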
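Put precisely, fail(x) is the longest proper suffix of x that is also a prefix of some dictionary word, or the root (the empty string) if there is none. The brute-force sketch below applies this definition directly; it is quadratic and only meant for checking results, and the class name `FailByDefinition` is hypothetical.

```java
import java.util.HashSet;
import java.util.Set;

class FailByDefinition {
    // All prefixes of the dictionary words = all trie nodes ("" is the root)
    static Set<String> prefixes(String... words) {
        Set<String> set = new HashSet<>();
        for (String w : words)
            for (int i = 0; i <= w.length(); i++) set.add(w.substring(0, i));
        return set;
    }

    // fail(node) = longest proper suffix of node that is also a trie node
    static String fail(String node, Set<String> nodes) {
        for (int start = 1; start <= node.length(); start++) {
            String suffix = node.substring(start);
            if (nodes.contains(suffix)) return suffix;
        }
        return ""; // only the root is left
    }

    public static void main(String[] args) {
        Set<String> nodes = prefixes("he", "she", "his", "hers");
        for (String n : new String[]{"sh", "she", "hers", "his"})
            System.out.println(n + " -> \"" + fail(n, nodes) + "\"");
    }
}
```

It prints sh -> "h", she -> "he", hers -> "s" and his -> "s", in line with node 5 failing to node 2 in the walk above.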
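Building the Fail Pointers
Take the same trie as an example:
- Create a queue Q and enqueue node 0 (the root).
- Dequeue node 0. Since it is the root, the fail pointers of all its children point to the root: nodes 1 and 3 both fail to node 0. Enqueue nodes 1 and 3.
- Dequeue node 1. Follow edge e and enqueue node 2. Then pretend edge e failed: look for edge e below node 0, the target of node 1's fail pointer. It is not found and node 0 is the root, so node 2's fail pointer is node 0.
- Dequeue node 3. Follow edge h and enqueue node 4. Pretend edge h failed: look for edge h below node 0, the target of node 3's fail pointer. It leads to node 1, so node 4's fail pointer is node 1.
- Dequeue node 2. Follow edge r and enqueue node 8. Pretend edge r failed: look for edge r below node 0, the target of node 2's fail pointer. It is not found and node 0 is the root, so node 8's fail pointer is node 0.
- Dequeue node 4. Follow edge e and enqueue node 5. Pretend edge e failed: look for edge e below node 1, the target of node 4's fail pointer. It leads to node 2, so node 5's fail pointer is node 2. Node 2 is a terminal node, so its word is added to node 5's outputs.
Repeat these steps for the remaining nodes to obtain all the fail pointers shown in the figure.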
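The same BFS can be sketched without node numbers by naming each node after its prefix string. This illustrative Java sketch (the class `FailBfs` is hypothetical, and `Map`/`Set` lookups stand in for child arrays) follows the queue order and the "pretend the edge failed" step described above:

```java
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

class FailBfs {
    // Returns node -> fail target, where each node is named by its prefix ("" = root)
    static Map<String, String> build(String... words) {
        // children.get(p) = characters c such that p + c is a prefix of some word
        Map<String, Set<Character>> children = new HashMap<>();
        for (String w : words)
            for (int i = 0; i < w.length(); i++)
                children.computeIfAbsent(w.substring(0, i), k -> new TreeSet<>()).add(w.charAt(i));

        Map<String, String> fail = new LinkedHashMap<>();
        Deque<String> queue = new ArrayDeque<>();
        queue.add(""); // enqueue the root
        while (!queue.isEmpty()) {
            String p = queue.poll();
            for (char c : children.getOrDefault(p, Collections.emptySet())) {
                String child = p + c;
                queue.add(child);
                if (p.isEmpty()) { // children of the root fail to the root
                    fail.put(child, "");
                    continue;
                }
                // "Pretend" edge c failed: walk p's fail chain until some node has a c-edge
                String q = fail.get(p);
                while (!q.isEmpty() && !hasChild(children, q, c)) q = fail.get(q);
                fail.put(child, hasChild(children, q, c) ? q + c : "");
            }
        }
        return fail;
    }

    static boolean hasChild(Map<String, Set<Character>> children, String node, char c) {
        return children.getOrDefault(node, Collections.emptySet()).contains(c);
    }

    public static void main(String[] args) {
        build("he", "she", "his", "hers").forEach((node, f) ->
                System.out.println(node + " -> \"" + f + "\""));
    }
}
```

For she, he, hers and his it reports, among others, sh -> "h", she -> "he" and her -> "" (the root), which correspond to nodes 4, 5 and 8 in the steps above.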
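import java.util.LinkedList;
import java.util.Queue;

// AC automaton over a character stream: query(letter) returns true when the
// characters received so far end with one of the words (lowercase a-z only).
class StreamChecker {
    ACNode root;  // trie root
    ACNode p;     // current state after the characters seen so far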
    public StreamChecker(String[] words) {
        // Build the trie
        this.root = new ACNode(' ');
        this.p = root;
        for (String word : words) {
            ACNode temp = root;
            for (char c : word.toCharArray()) {
                int idx = c - 'a';
                if (temp.children[idx] == null) temp.children[idx] = new ACNode(c);
                temp = temp.children[idx];
            }
            temp.isEnding = true;
            temp.length = word.length();
        }
        // Set up the fail pointers
        buildFailPointer();
    }
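    // BFS over the trie; a node's fail pointer is found by following its parent's fail chain
    private void buildFailPointer() {
        Queue<ACNode> queue = new LinkedList<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            ACNode cur = queue.poll();  // node being expanded (not the stream state p)
            for (int i = 0; i < 26; i++) {
                ACNode pc = cur.children[i];
                if (pc == null) continue;
                if (cur == root) {
                    pc.fail = root;  // children of the root always fail to the root
                } else {
                    // "Pretend" edge i failed: walk up cur's fail chain until a node has an i-edge
                    ACNode q = cur.fail;
                    while (q != null && q.children[i] == null) {
                        q = q.fail;
                    }
                    // q == null means we went past the root: fail to the root
                    pc.fail = (q == null) ? root : q.children[i];
                }
                queue.add(pc);
            }
        }
    }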
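    public boolean query(char letter) {
        int idx = letter - 'a';
        // On a mismatch, follow fail pointers until a node has this edge or we reach the root
        while (p != root && p.children[idx] == null) {
            p = p.fail;
        }
        p = p.children[idx];
        if (p == null) p = root;  // no edge even at the root: start over
        // A word ends here if this node or any node on its fail chain is terminal
        ACNode temp = p;
        while (temp != root) {
            if (temp.isEnding) return true;
            temp = temp.fail;
        }
        return false;
    }
}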
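class ACNode {
    char c;                // character on the edge into this node
    boolean isEnding;      // true if a word ends at this node
    int length = -1;       // length of that word (-1 if none)
    ACNode[] children = new ACNode[26];
    ACNode fail;           // fail pointer
    ACNode(char c) {
        this.c = c;
    }
}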
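A possible refinement of `query`, not part of the code above: since a node's fail target is always discovered earlier in the BFS, the terminal flag can be OR-ed along the fail pointer while building, so `query` only checks the current node instead of walking the fail chain. A minimal sketch under the same lowercase a-z assumption (the class `CompactStreamChecker` is hypothetical):

```java
import java.util.ArrayDeque;
import java.util.Queue;

class CompactStreamChecker {
    private static final class Node {
        final Node[] next = new Node[26];
        Node fail;
        boolean ending; // true if some word is a suffix of this node's prefix
    }

    private final Node root = new Node();
    private Node cur = root;

    CompactStreamChecker(String... words) {
        for (String w : words) {
            Node n = root;
            for (char c : w.toCharArray()) {
                int i = c - 'a';
                if (n.next[i] == null) n.next[i] = new Node();
                n = n.next[i];
            }
            n.ending = true;
        }
        Queue<Node> queue = new ArrayDeque<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            Node p = queue.poll();
            for (int i = 0; i < 26; i++) {
                Node child = p.next[i];
                if (child == null) continue;
                Node q = p.fail;
                while (q != null && q.next[i] == null) q = q.fail;
                child.fail = (q == null) ? root : q.next[i];
                // The fail target is shallower, so its flag is already final
                child.ending |= child.fail.ending;
                queue.add(child);
            }
        }
    }

    boolean query(char letter) {
        int i = letter - 'a';
        while (cur != root && cur.next[i] == null) cur = cur.fail;
        cur = (cur.next[i] == null) ? root : cur.next[i];
        return cur.ending;
    }

    public static void main(String[] args) {
        CompactStreamChecker sc = new CompactStreamChecker("he", "she", "his", "hers");
        StringBuilder sb = new StringBuilder();
        for (char c : "ushers".toCharArray()) sb.append(sc.query(c) ? 'T' : 'F');
        System.out.println(sb); // FFFTFT
    }
}
```

Feeding "ushers" one character at a time gives false for u, s and h, true for e ("she" and "he" end there), false for r and true for the final s ("hers").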
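Double-Array AC Automaton
The earlier article "Double-Array Trie (DoubleArrayTrie)" also covered the double-array AC automaton. Compared with a plain double-array trie it:
- adds a fail array that stores the fail pointers;
- adds an output array that stores the words recognized at each position.
For she, he, hers and his the arrays are as follows (position 0 is the root, position 1 is unused, and a negative base marks a position where a word ends):
Character codes: e=1, h=2, i=3, r=4, s=5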
char | - | - | e | h | i | h | s | r | s | e | s |
---|---|---|---|---|---|---|---|---|---|---|---|
pos | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
base | 1 | 0 | -3 | 1 | 3 | 8 | 3 | 5 | -3 | -8 | -5 |
check | -1 | -1 | 3 | 0 | 3 | 6 | 0 | 2 | 4 | 5 | 7 |
output | - | - | he | - | - | - | - | - | his | she he | hers |
fail | -1 | -1 | 0 | 0 | 0 | 3 | 0 | 0 | 6 | 2 | 6 |
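The transition rule behind the table: from position s, character code c leads to t = |base[s]| + c, and the move is valid only if check[t] == s; otherwise follow fail[s] and retry. As a sanity check, this minimal sketch copies the arrays from the table verbatim and replays "ushers". The class `TableWalk` is hypothetical, and treating an index past the end of the arrays as a mismatch is an assumption of the sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class TableWalk {
    // Arrays copied from the table above; character codes e=1, h=2, i=3, r=4, s=5
    static final String KEYS = "ehirs";
    static final int[] BASE = {1, 0, -3, 1, 3, 8, 3, 5, -3, -8, -5};
    static final int[] CHECK = {-1, -1, 3, 0, 3, 6, 0, 2, 4, 5, 7};
    static final int[] FAIL = {-1, -1, 0, 0, 0, 3, 0, 0, 6, 2, 6};
    static final String[][] OUTPUT = {null, null, {"he"}, null, null, null, null, null,
            {"his"}, {"she", "he"}, {"hers"}};

    static List<String> parse(String text) {
        List<String> found = new ArrayList<>();
        int s = 0; // position 0 is the root
        for (int i = 0; i < text.length(); i++) {
            int code = KEYS.indexOf(text.charAt(i)) + 1; // 0 means "not in the alphabet"
            if (code == 0) { s = 0; continue; }
            int t = Math.abs(BASE[s]) + code; // a negative base only marks a word end
            if (t < CHECK.length && CHECK[t] == s) {
                s = t;
                if (OUTPUT[s] != null) found.addAll(Arrays.asList(OUTPUT[s]));
            } else if (s != 0) {
                s = FAIL[s]; // mismatch: follow the fail pointer and retry this character
                i--;
            }
        }
        return found;
    }

    public static void main(String[] args) {
        System.out.println(parse("ushers")); // [she, he, hers]
    }
}
```

At 'r' the state is position 9, where |base| + 4 = 12 has no valid entry, so the walk falls back to fail[9] = 2 and then reaches 7 and 10, exactly as in the linked-node walk.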
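import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class AcDoubleArrayTrie {
    String[] keys;       // sorted alphabet; the code of keys[i] is i + 1
    int[] base;          // transitions: child of s via code c sits at |base[s]| + c; negative = word end
    int[] check;         // check[t] = parent position of t, -1 = free slot
    int[] fail;          // fail pointers (-1 = unset)
    String[][] output;   // words recognized at each position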
    // A child waiting to be placed: its character code and its parent's position
    private static class Node {
        private int code;  // character code
        private int s;     // parent position

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Node node = (Node) o;
            return code == node.code && s == node.s;
        }

        @Override
        public int hashCode() {
            return 31 * code + s;
        }
    }
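    public void build(List<String> words) {
        // Work on a de-duplicated copy: the grouping loop removes finished words, and a
        // duplicate word would flip its base sign twice. Empty words are ignored.
        List<String> list = new ArrayList<>(new LinkedHashSet<>(words));
        list.removeIf(String::isEmpty);
        // Code every distinct character in sorted order, starting at 1
        this.keys = list.stream().map(word -> word.split("")).flatMap(Arrays::stream)
                .distinct().sorted().toArray(String[]::new);
        // The initial size is only a guess; grow() enlarges the arrays on demand
        int size = 3 * keys.length + 1;
        base = new int[size];
        check = new int[size];
        fail = new int[size];
        output = new String[size][];
        Arrays.fill(check, -1);
        Arrays.fill(fail, -1);
        String[] dir = list.toArray(new String[0]);
        // The root sits at position 0
        base[0] = 1;
        // Place the trie level by level; depth = length of the prefixes placed in this round
        int depth = 1;
        while (!list.isEmpty()) {
            // Group this level's characters by the position of their parent prefix
            Map<Integer, List<Node>> map = new HashMap<>();
            for (int i = 0; i < list.size(); ) {
                String word = list.get(i);
                Node n = new Node();
                n.code = findIndex(word.substring(depth - 1, depth));
                n.s = depth == 1 ? 0 : indexOf(word.substring(0, depth - 1));
                // A word whose last character is placed in this round is finished
                if (depth == word.length()) {
                    list.remove(i);
                } else {
                    i++;
                }
                List<Node> siblings = map.computeIfAbsent(n.s, k -> new ArrayList<>());
                if (!siblings.contains(n)) {
                    siblings.add(n);
                }
            }
            // The order in which groups are placed affects the final layout
            for (Map.Entry<Integer, List<Node>> e : map.entrySet()) {
                int s = e.getKey();
                List<Node> siblings = e.getValue();
                // Shift base[s] right until every child slot base[s] + code is free
                int b = base[s];
                while (!fits(b, siblings)) {
                    b++;
                }
                base[s] = b;
                for (Node node : siblings) {
                    int t = b + node.code;
                    grow(t + 1);
                    check[t] = s;  // record the parent
                    base[t] = b;   // the child's own search starts from the parent's base
                }
            }
            depth++;
        }
        // Mark word ends: negate base and store the word as the node's output
        for (String word : dir) {
            int s = indexOf(word);
            base[s] = -base[s];
            output[s] = new String[]{word};
        }
        constructFail();
    }

    // True if b + code is a free slot (or lies beyond the arrays) for every sibling
    private boolean fits(int b, List<Node> siblings) {
        for (Node node : siblings) {
            int t = b + node.code;
            if (t < check.length && check[t] != -1) {
                return false;
            }
        }
        return true;
    }

    // Enlarge all arrays to at least minSize, marking the new slots as free/unset
    private void grow(int minSize) {
        int oldSize = check.length;
        if (minSize <= oldSize) {
            return;
        }
        int newSize = Math.max(minSize, 2 * oldSize);
        base = Arrays.copyOf(base, newSize);
        check = Arrays.copyOf(check, newSize);
        fail = Arrays.copyOf(fail, newSize);
        output = Arrays.copyOf(output, newSize);
        Arrays.fill(check, oldSize, newSize, -1);
        Arrays.fill(fail, oldSize, newSize, -1);
    }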
    private void constructFail() {
        Queue<Integer> queue = new ArrayDeque<>();
        // Step 1: depth-1 nodes (parent == root) fail to the root
        for (int i = 0; i < check.length; i++) {
            if (check[i] != 0) {
                continue;
            }
            fail[i] = 0;
            queue.add(i);
        }
        // Step 2: BFS to build the fail table for nodes deeper than 1
        while (!queue.isEmpty()) {
            int currentIndex = queue.remove();
            int current = Math.abs(base[currentIndex]);
            for (int target = 0; target < check.length; target++) {
                if (check[target] != currentIndex) {
                    continue;
                }
                queue.add(target);
                int code = target - current;  // code of the edge currentIndex -> target
                int f = fail[currentIndex];
                while (true) {
                    // base[f] is negative at word ends, so use its absolute value;
                    // t may also lie past the end of the arrays
                    int t = Math.abs(base[f]) + code;
                    if (t < check.length && check[t] == f) {
                        fail[target] = t;
                        constructOutput(target, t);
                        break;
                    } else if (f == 0) {
                        fail[target] = 0;
                        break;
                    }
                    f = fail[f];
                }
            }
        }
    }

    // Append the fail target's outputs; BFS order guarantees they are already complete
    private void constructOutput(int target, int failTarget) {
        if (output[failTarget] == null) {
            return;
        }
        if (output[target] == null) {
            // A non-terminal node still reports the words ending at its fail target
            output[target] = output[failTarget].clone();
            return;
        }
        String[] merged = Arrays.copyOf(output[target], output[target].length + output[failTarget].length);
        System.arraycopy(output[failTarget], 0, merged, output[target].length, output[failTarget].length);
        output[target] = merged;
    }
    // Character code of key (1-based), or -1 if key is not in the alphabet
    private int findIndex(String key) {
        int i = Arrays.binarySearch(keys, key);  // keys is sorted
        return i >= 0 ? i + 1 : -1;
    }

    // Position of the node reached by walking the prefix pre from the root
    private int indexOf(String pre) {
        int s = 0;
        for (String ch : pre.split("")) {
            s = Math.abs(base[s]) + findIndex(ch);
        }
        return s;
    }

    public List<String> parseText(String text) {
        List<String> result = new ArrayList<>();
        int s = 0;  // current state, starting at the root
        String[] ss = text.split("");
        for (int i = 0; i < ss.length; i++) {
            int c = findIndex(ss[i]);
            if (c < 0) {
                // Character outside the alphabet: no word can span it, restart at the root
                s = 0;
                continue;
            }
            int t = Math.abs(base[s]) + c;
            if (t < check.length && check[t] == s) {
                // Transition exists: move and report every word ending here
                if (output[t] != null) {
                    Collections.addAll(result, output[t]);
                }
                s = t;
            } else if (s != 0) {
                // Mismatch: follow the fail pointer and retry the same character
                s = fail[s];
                i--;
            }
            // A mismatch at the root simply skips the character
        }
        return result;
    }
    public static void main(String[] args) {
        AcDoubleArrayTrie adt = new AcDoubleArrayTrie();
        List<String> list = Stream.of("hers", "his", "she", "he").collect(Collectors.toList());
        // Build the double-array AC automaton
        adt.build(list);
        List<String> result = adt.parseText("ushers");
        result.forEach(System.out::println);  // prints she, he, hers (one per line)
    }
}