思路:
这道题的主题思路其实很简单,在给定字符串里面找到一个模式串(题目给的集合的字串),看其前后的字符的个数,分别加一,然后相乘就可以了,难点是要找到所有的字串和位置,暴力肯定不行,所以姐妹们就想到了AC自动机,确实,了解了AC自动机的机制之后其实就是一道送分题,在建树的时候把下标和长度记录一下,在字典树里面query的时候每找到一个字串就进行上述的操作,知道O(m + n )找完所有的字串,就能直接得出结果。
ps:这道题月月和花花很快就有思路了,超级超级厉害,然后AC自动机是我们第一次接触,大概花了一个小时的时间看了董晓老师的视频学习了这个算法,但是虽然理解了,但是还不能完全熟练的运用,还是需要多练习,多做相关的题目,板子是队长找到的板子。
真的被这个字串搜索的快速程度惊呆,太天才了!向Alfred V. Aho和Margaret J.Corasick先生致敬
#include <queue>
#include <cstdlib>
#include <cmath>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
#include "map"
using namespace std;
typedef long long ll;
const int maxn = 2*1e6+9;
long long xx = 0;
long long len;
int trie[maxn][26]; //字典树
int cntword[maxn]; //记录该单词出现次数
int fail[maxn]; //失败时的回溯指针
int cnt = 0;
long long modd = 1e9 + 7;
map<int, int> mp;
void insertWords(string s){
int root = 0;
for(int i=0;i<s.size();i++){
int next = s[i] - 'a';
if(!trie[root][next])
trie[root][next] = ++cnt, mp[cnt] = i + 1;
root = trie[root][next];
}
cntword[root]++; //当前节点单词数+1
}
void getFail(){
queue <int>q;
for(int i=0;i<26;i++){ //将第二层所有出现了的字母扔进队列
if(trie[0][i]){
fail[trie[0][i]] = 0;
q.push(trie[0][i]);
}
}
//fail[now] ->当前节点now的失败指针指向的地方
//tire[now][i] -> 下一个字母为i+'a'的节点的下标为tire[now][i]
while(!q.empty()){
int now = q.front();
q.pop();
for(int i=0;i<26;i++){ //查询26个字母
if(trie[now][i]){
//如果有这个子节点为字母i+'a',则
//让这个节点的失败指针指向(((他父亲节点)的失败指针所指向的那个节点)的下一个节点)
//有点绕,为了方便理解特意加了括号
fail[trie[now][i]] = trie[fail[now]][i];
q.push(trie[now][i]);
}
else//否则就让当前节点的这个子节点
//指向当前节点fail指针的这个子节点
trie[now][i] = trie[fail[now]][i];
}
}
}
void query(string s){
int now = 0,ans = 0;
for(int i=0;i<s.size();i++){ //遍历文本串
now = trie[now][s[i]-'a']; //从s[i]点开始寻找
for(int j=now;j ;j=fail[j]){
//一直向下寻找,直到匹配失败(失败指针指向根或者当前节点已找过).
ans += cntword[j];
if(cntword[j] > 0) xx =(xx + (i - mp[j] + 2) * (len - i) ) % modd;
//cntword[j] = -1; //将遍历国后的节点标记,防止重复计算
}
}
// return ans;
}
int main() {
int n, m;
string s;
cin >> n >> m;
for(int i=0;i<n;i++){
cin >> s ;
insertWords(s);
}
fail[0] = 0;
getFail();
while(m --){
xx = 0;
cin >> s ;
len = s.length() ;
query(s) ;
cout<<xx<<endl;
}
return 0;
}