AC自动机算法

前言

有时我们在想《字典树》是不是也能引入快速失败《字符串模式匹配算法KMP》的方式,从而加快词语匹配。它就是AC自动机算法。

AC自动机算法

首先我们使用she , he , hers, his 等单词,使用AC自动机算法构建一棵字典树。
在这里插入图片描述
与普通的字典树的区别是多了一些虚线指针,为“Fail指针”。

验证fail指针

我们通过单词“ushers” 对其验证。

  1. 输入‘u’, 0节点下面没找到,失败,根据fail指针指向0结点, 因是0结点终止搜索
  2. 输入’s’, 前进至3结点
  3. 输入‘h’,前进至4结点
  4. 输入‘e’,前进至5结点, 5节点是终结点收获{”he“,”she“}
  5. 输入‘r’,5节点下面没找到,失败,根据fail指针指向2结点,搜索2结点下面‘r’前进至8结点
  6. 输入‘s’,前进至9结点, 9节点是终结点收获{”hers“}
  7. 因此得到单词 {”he“,”she“, ”hers“}

经过上面查询规则发现,所谓的失败指针,就是匹配不上时,查找最长公共后缀转过去

构建fail指针

再以上述词典树为例

  1. 创建一队列Q,把结点0(root)进行

  2. 结点0出队,发现是root,把结点0的孩子的失败指针都指向root,结点1,3的失败指针都结点0。再把结点1,3入队列Q

  3. 结点1出队,根据路径e,把结点2入队列Q. 然后假装“路径e”失败,查询当前fail指针下结点0的“路径e”. 没找到且发现是root,结点2的失败指针都结点0

  4. 结点3出队,根据路径h,把结点4入队列Q. 然后假装“路径h”失败,查询当前fail指针下结点0的“路径h”. 找到结点1,结点4的失败指针都结点1

  5. 结点2出队,根据路径r,把结点8入队列Q. 然后假装“路径r”失败,查询当前fail指针下结点0的“路径r”. 没找到且发现是root,结点8的失败指针都结点0

  6. 结点4出队,根据路径e,把结点5入队列Q. 然后假装“路径e”失败,查询当前fail指针下结点1的“路径e”. 找到结点2,结点5的失败指针都结点2, 发现2结点是终节节点,并把他的词加入自己结点中

  7. 以至类推,重复上述步骤,就构造出示图中的fail指针

class StreamChecker {
    ACNode root;
    ACNode p;

    public StreamChecker(String[] words) {
        // 构造字典树
        this.root = new ACNode(' ');
        this.p = root;
        for (String word : words) {
            ACNode temp = root;
            for (char c : word.toCharArray()) {
                int idx = c - 'a';
                if (temp.children[idx] == null) temp.children[idx] = new ACNode(c);
                temp = temp.children[idx];
            }
            temp.isEnding = true;
            temp.length = word.length();
        }
        // 维护失败指针
        buildFailPointer();
    }

    private void buildFailPointer() {
        Queue<ACNode> queue = new LinkedList<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            ACNode p = queue.poll();
            for (int i = 0; i < 26; i++) {
                ACNode pc = p.children[i];
                if (pc == null) continue;

                if (p == root) pc.fail = root;
                else {
                    ACNode q = p.fail;
                    while (q != null && q.children[i] == null) {
                        q = q.fail;
                    }

                    if (q == null) pc.fail = root;
                    else pc.fail = q.children[i];
                }
                queue.add(pc);
            }
        }
    }
    
    public boolean query(char letter) {
        int idx = letter - 'a';
        while (this.p != root && p.children[idx] == null) {
            p = p.fail;
        }
        p = p.children[idx];
        if (p == null) p = root;

        ACNode temp = p;
        while (temp != root) {
            if (temp.isEnding) return true;
            temp = temp.fail;
        }

        return false;
    }
}

class ACNode {
    char c;
    boolean isEnding;
    int length = -1;

    ACNode[] children = new ACNode[26];
    ACNode fail;

    ACNode(char c) {
        this.c = c;
    }
}

双数组AC自动机算法

当然在《双数组字典树DoubleArrayTrie》也研究了"双数组ac自动机".

  1. 引出fail数组表示失败指针
  2. 引入output 表示ac的字符集

对she , he , hers, his的数组表为

字符值: [e-1,h-2,i-3,r-4,s-5]

字符--ehihsrses
pos012345678910
base10-313835-3-8-5
check-1-1303602457
output--he-----hisshe
he
hers
fail-1-1000300626
public class AcDoubleArrayTrie {
 
    String[] keys;// 字符集
    int[] base;// 转移数组
    int[] check;// 较验数组
    int fail[]; //fail表
    String[][] output;//输出表
 
    private static class Node {
 
        private int code;// 字符编码
 
        private int s;// 父字符位置

        @Override
        public boolean equals(Object o) {
            if (this == o)
                return true;
            if (o == null || getClass() != o.getClass())
                return false;

            Node node = (Node) o;

            if (code != node.code)
                return false;
            return s == node.s;
        }

        @Override
        public int hashCode() {
            int result = code;
            result = 31 * result + s;
            return result;
        }
    }
 
    public void build(List<String> list) {
 
        // 给所有字符定编码
        this.keys = list.stream().map(word -> word.split("")).flatMap(Arrays::stream).distinct().sorted()
                .collect(Collectors.toList()).toArray(new String[0]);
 
        base = new int[3 * keys.length];
        check = new int[3 * keys.length];
        fail = new int[3* keys.length];
        output = new String[3 * keys.length][];

        String[] dir = list.toArray(new String[0]);
 
        // 设置root
        base[0] = 1;
        for (int i = 0; i < check.length ; i++) {
            check[i] = -1;
        }
        for (int i = 0; i < fail.length ; i++) {
            fail[i] = -1;
        }
 
        // 词的深度
        int depth = 1;
 
        while (!list.isEmpty()) {
 
            // 根据相同前缀分组
            Map<Integer, List<Node>> map = new HashMap<>();
            for (int i = 0; i < list.size();) {
                String word = list.get(i);

                String pre = word.substring(0, depth - 1);
                String k = word.substring(depth - 1, depth);

                Node n = new Node();
                n.code = findIndex(k);
                n.s = depth == 1 ? 0 : indexOf(pre);
                if (depth == word.length()) {
                    list.remove(i);
                } else {
                    i++;
                }

                List<Node> siblings = map.getOrDefault(n.s, new ArrayList<>());

                if(siblings.contains(n)){
                    continue;
                }
                siblings.add(n);
                map.put(n.s, siblings);
            }
 
            map.forEach((s, siblings) -> {
                int offset = 0;

                for (int i = 0; i < siblings.size(); i++) {
                    Node node = siblings.get(i);
                    int c = node.code;
                    int t = base[s] + offset + c;

                    // 发现在节点已有值则偏移+1
                    if (check[t] != -1) {
                        offset++;
                        i = -1;
                    }
                }

                base[s] = base[s] + offset;

                for (Node node : siblings) {
                    int c = node.code;
                    int t = base[s] + c;
                    // 给上父结点
                    check[t] = s;
                    // 给拿上一个节点偏移量
                    base[t] = base[s];
                }
            });
            depth++;
        }
 
        // 发现字节点,置为负数
        for (String aDir : dir) {
            int s = indexOf(aDir);
            base[s] = -1 * base[s];
            output[s] = new String[]{aDir};
        }
        constructFail();
    }

    private void constructFail() {

        Queue<Integer> queue = new LinkedBlockingDeque<>();

        // 第一步,将深度为1的节点的failure设为根节点
        for (int i = 0; i < check.length; i++) {
            if (check[i] != 0) {
                continue;
            }
            fail[i] = 0;
            queue.add(i);
        }

        // 第二步,为深度 > 1 的节点建立failure表,这是一个bfs
        while (!queue.isEmpty()) {
            int currentIndex = queue.remove();
            int current = base[currentIndex];

            for (int target = 0; target < check.length; target++) {
                if (check[target] != currentIndex) {
                    continue;
                }
                queue.add(target);
                int code = target - (current < 0 ? -1 * current : current);

                int currentFailIndex = fail[currentIndex];

                while (true) {
                    if (check[base[currentFailIndex] + code] == currentFailIndex) {
                        fail[target] = base[currentFailIndex] + code;
                        constructOutput(target, base[currentFailIndex] + code);
                        break;
                    } else if (currentFailIndex == 0) {
                        fail[target] = 0;
                        break;
                    }
                    currentFailIndex = fail[currentFailIndex];
                }

            }
        }
    }

    private void constructOutput(int target, int failTarget) {
        if (output[target] == null || output[failTarget] == null) {
            return;
        }
        List<String> result = Stream.of(output[target]).collect(Collectors.toList());
        result.addAll(Stream.of(output[failTarget]).collect(Collectors.toList()));
        output[target] = result.toArray(new String[0]);
    }
 
    // 找询字符编码
    private int findIndex(String key) {
        for (int i = 0; i < keys.length; i++) {
            if (keys[i].equals(key))
                return i + 1;
        }
        throw new RuntimeException("找不到[" + key + "]");
    }
 
    // 定位前缀结点position
    private int indexOf(String pre) {
        int s = 0;
        String[] ss = pre.split("");
        for (String word : ss) {
            int c = findIndex(word);
            s = (base[s] < 0 ? -1 * base[s] : base[s]) + c;
        }
        return s;
    }

    public List<String> parseText(String text){
        List<String> result = new ArrayList<>();
        int s = 0;
        String[] ss = text.split("");

        for (int i = 0; i < ss.length; i++) {
            String word = ss[i];
            int c;
            try{
                c= findIndex(word);
            }catch (Exception e){
                s = 0;
                continue;
            }

            int t = (base[s] < 0 ? -1 * base[s] : base[s]) + c;

            if (check[t] == s) {
                if(output[t] !=null){
                    result.addAll(Stream.of(output[t]).collect(Collectors.toList()));
                }
                s = t;
            } else {
                if(fail[s] == -1){
                    continue;
                }
                s = fail[s];
                i--;
            }

        }
        return result;
    }

    public static void main(String[] args) {
        AcDoubleArrayTrie adt = new AcDoubleArrayTrie();
        List<String> list = Stream.of(new String[]{"hers", "his", "she", "he"}).collect(Collectors.toList());
 
        // 构建DoubleArrayTrie
        adt.build(list);

        List<String> result = adt.parseText("ushers");

        result.forEach(System.out::println);
    }
}

主要参考

AC自动机 - 关于Fail指针

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值