AC 自动机还是看视频比较好懂,B 站讲解:传送门
拓扑排序优化的讲解:传送门
优化的思想:匹配时不再每走到一个结点就暴力沿 fail 边往上跳,而是先把访问次数记在当前结点上,最后按 fail 树的拓扑序从下往上一次性累加(fail 边构成一棵以根为根的树,所以拓扑序一定存在)。
下面贴出整理好的模板:
P3796 【模板】AC自动机(加强版)
#include<bits/stdc++.h>
// Contest-template shorthands. This program only uses il, ms and sc;
// pb, SC, SZ, rep and drep are unused here.
#define il inline
#define pb push_back
// ms(_data,v): memset(_data, v, sizeof(_data)) -- fills every byte of an array.
#define ms(_data,v) memset(_data,v,sizeof(_data))
// sc(n) / SC(n,m): read one / two ints; the macro evaluates to scanf's return value.
#define sc(n) scanf("%d",&n)
#define SC(n,m) scanf("%d %d",&n,&m)
#define SZ(a) int((a).size())
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define drep(i,a,b) for(int i=a;i>=b;--i)
using namespace std;
typedef long long ll;
// inf, pi and eps are not used by this program.
const ll inf=0x3f3f3f3f;
const double pi=acos(-1.0);
const double eps=1e-9;
// Trie node capacity. Presumably sized for up to 155 patterns shorter than 75
// characters (cf. char a[155][75] used by main) -- confirm against the problem limits.
const int maxn=155*75;
// Size of the text buffer s.
const int maxs=1e6+5;
// Unused helpers; they reference a `mod` constant that this program does not define.
//il int Add(int x,int y) {return x+y>=mod?x+y-mod:x+y;}
//il int Mul(ll x,int y) {return x*y>=mod?x*y%mod:x*y;}
// mp[i]: id of the pattern whose counter answers for pattern i (duplicate patterns
//        map to the id of their first occurrence; set in AHO::insert).
// res[id]: occurrence count of pattern id, written by AHO::pushup.
int mp[maxn],res[maxn];
// Aho-Corasick automaton for P3796 (several test cases; call init() before each).
// Node 1 is the root. Node 0 is a sentinel whose 26 children all point at the root
// (set in getFail), so depth-1 nodes get fail = 1 without a special case.
// match() only counts how often each state is reached; pushup() then adds those
// counts up the fail tree in topological order (deepest suffix states first),
// instead of walking fail links at every text position.
struct AHO {
    struct node {
        int son[26]; // trie edges; after getFail() missing edges become goto transitions
        int fail;    // fail link (longest proper suffix state); 0 only for the root
        int flag;    // id of the first inserted pattern ending here, 0 if none
        int ans;     // visits counted by match(); after pushup() the occurrence count
    } trie[maxn];
    queue<int> q;
    int cnt;       // highest node index in use
    int in[maxn];  // number of nodes whose fail link points at this node

    // Reset node `id` to an empty state (no edges, links, pattern, counts or in-degree).
    il void clear(int id) {
        ms(trie[id].son,0);
        trie[id].ans=trie[id].fail=trie[id].flag=0;
        in[id]=0;
    }
    // Insert pattern s (lowercase letters) with id num. Duplicate patterns keep the
    // id of their first occurrence, and mp[num] records which id holds the answer.
    void insert(char*s ,int num) {
        int root=1,len=strlen(s),id;
        for(int i=0; i<len; ++i) {
            id=s[i]-'a';
            if(!trie[root].son[id]) {
                trie[root].son[id]=++cnt;
                clear(cnt);
            }
            root=trie[root].son[id];
        }
        if(!trie[root].flag) trie[root].flag=num;
        mp[num]=trie[root].flag;
    }
    // BFS from the root: set fail links, turn missing edges into goto transitions
    // (copied from the fail state), and count fail-tree in-degrees for pushup().
    void getFail() {
        for(int i=0; i<26; ++i) trie[0].son[i]=1;
        q.push(1);
        int u,ufail,v;
        while(!q.empty()) {
            u=q.front(),ufail=trie[u].fail,q.pop();
            for(int i=0; i<26; ++i) {
                v=trie[u].son[i];
                if(!v) {
                    trie[u].son[i]=trie[ufail].son[i];
                    continue;
                }
                trie[v].fail=trie[ufail].son[i];
                in[trie[v].fail]++,q.push(v);
            }
        }
    }
    // Topological pass over the fail tree: a node is popped only after every node
    // whose fail link points at it, so its ans is final when copied into res.
    void pushup() {
        for(int i=1; i<=cnt; ++i) if(!in[i]) q.push(i);
        int u,v;
        while(!q.empty()) {
            u=q.front(),q.pop();
            res[trie[u].flag]=trie[u].ans; // nodes without a pattern write the unused res[0]
            v=trie[u].fail;
            // The root's fail is the sentinel 0; do not touch its in-degree or counter,
            // which would otherwise drift across test cases.
            if(!v) continue;
            in[v]--;
            trie[v].ans+=trie[u].ans;
            if(in[v]==0) q.push(v);
        }
    }
    // Feed text s (lowercase letters) through the automaton, counting visits per
    // state, then aggregate the counts with pushup().
    void match(char *s) {
        int root=1,len=strlen(s),id;
        for(int i=0; i<len; ++i) {
            id=s[i]-'a';
            root=trie[root].son[id],trie[root].ans++;
        }
        pushup();
    }
    // Start a new test case: only an empty root (node 1) remains. Reusing clear()
    // also resets the root's flag, which the previous version skipped (it reset
    // ans twice by mistake).
    void init() {
        cnt=1;
        clear(1);
        while(!q.empty()) q.pop();
    }
}aho;
int T,n;
char a[155][75];
char s[maxs];
int main() {
std::ios::sync_with_stdio(0);
while(sc(n)!=EOF && n) {
aho.init();
for(int i=1; i<=n; ++i) {
scanf("%s",a[i]);
aho.insert(a[i],i);
}
aho.getFail();
scanf("%s",s);
aho.match(s);
int mx=-1;
for(int i=1; i<=n; ++i) {
if(res[mp[i]]>mx) mx=res[mp[i]];
}
printf("%d\n",mx);
for(int i=1; i<=n; ++i) {
if(res[mp[i]]==mx) printf("%s\n",a[i]);
}
}
return 0;
}
HDU 2222 Keywords Search
以下代码是跟着 UP 主的视频写的。
#include<bits/stdc++.h>
// Contest-template shorthands. This program only uses ms and sc;
// il, pb, SC, SZ, rep and drep are unused here.
#define il inline
#define pb push_back
// ms(_data,v): memset(_data, v, sizeof(_data)) -- fills every byte of an array.
#define ms(_data,v) memset(_data,v,sizeof(_data))
// sc(n) / SC(n,m): read one / two ints with scanf.
#define sc(n) scanf("%d",&n)
#define SC(n,m) scanf("%d %d",&n,&m)
#define SZ(a) int((a).size())
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define drep(i,a,b) for(int i=a;i>=b;--i)
using namespace std;
typedef long long ll;
// inf, pi and eps are not used by this program.
const ll inf=0x3f3f3f3f;
const double pi=acos(-1.0);
const double eps=1e-9;
// Trie node capacity (size of Aho::no).
const int maxn=5e5+5;
// Size of the text/keyword buffer s.
const int maxs=1e6+5;
// Unused helpers; they reference a `mod` constant that this program does not define.
//il int Add(int x,int y) {return x+y>=mod?x+y-mod:x+y;}
//il int Mul(ll x,int y) {return x*y>=mod?x*y%mod:x*y;}
// Aho-Corasick automaton counting how many inserted keywords (duplicates counted
// separately) occur in a text. Node 0 is the root and its fail link is -1.
// Call init() before each test case.
struct Aho{
    struct node{
        int nxt[26];   // trie children ('a'..'z'), 0 = absent
        int fail;      // fail link (-1 for the root)
        int cnt;       // keywords ending here; set to -1 once counted by get()
    }no[maxn];
    int sz;            // number of allocated nodes (node 0 is the root)
    queue<int> q;
    // Reset node id to an empty state.
    void clear(int id){
        ms(no[id].nxt,0);
        no[id].fail=no[id].cnt=0;
    }
    // Start a new test case. Only the root is wiped here; every other node is wiped
    // by insert() when it is allocated, instead of sweeping all maxn nodes each time.
    void init(){
        while(!q.empty()) q.pop();
        clear(0);
        sz=1;
    }
    // Add keyword s (lowercase letters) to the trie.
    void insert(char *s){
        int len=strlen(s),root=0,id;
        for(int i=0;i<len;++i){
            id=s[i]-'a';
            if(!no[root].nxt[id]){
                clear(sz); // the slot may still hold data from a previous test case
                no[root].nxt[id]=sz++;
            }
            root=no[root].nxt[id];
        }
        no[root].cnt++;
    }
    // BFS computing fail links: a child's fail is the first node on its parent's
    // fail chain that has an edge with the same letter (that edge's target), or
    // the root if no such node exists.
    void build(){
        no[0].fail=-1;
        q.push(0);
        while(!q.empty()){
            int u=q.front();
            q.pop();
            for(int i=0;i<26;++i){
                if(no[u].nxt[i]){
                    if(u==0) no[no[u].nxt[i]].fail=0;
                    else{
                        int v=no[u].fail;
                        while(v!=-1){
                            if(no[v].nxt[i]){
                                no[no[u].nxt[i]].fail=no[v].nxt[i];
                                break;
                            }
                            v=no[v].fail;
                        }
                        if(v==-1) no[no[u].nxt[i]].fail=0;
                    }
                    q.push(no[u].nxt[i]);
                }
            }
        }
    }
    // Sum the keyword counts along the fail chain starting at `root`, marking each
    // visited node with cnt=-1 so it is counted at most once. The walk stops at the
    // trie root or at an already-marked node; every node on a marked node's fail
    // chain was marked in the same earlier walk, so nothing is lost.
    int get(int root){
        int res=0;
        while(root>0 && no[root].cnt!=-1){
            res+=no[root].cnt;
            no[root].cnt=-1;
            root=no[root].fail;
        }
        return res;
    }
    // Run text s (lowercase letters) through the automaton and return the number of
    // keywords found. Each keyword node is counted once.
    int match(char *s){
        int len=strlen(s),id,res=0,root=0;
        for(int i=0;i<len;++i){
            id=s[i]-'a';
            if(no[root].nxt[id]) root=no[root].nxt[id];
            else{
                int p=no[root].fail;
                while(p!=-1 && no[p].nxt[id]==0) p=no[p].fail;
                if(p==-1) root=0;
                else root=no[p].nxt[id];
            }
            // Always walk the fail chain: a keyword can end here as a proper suffix of
            // the current state even when the state itself ends no keyword
            // (e.g. keywords "abc" and "b" in text "abc").
            res+=get(root);
        }
        return res;
    }
}aho;
int T,n;
char s[maxs];
// Reads T test cases. Each case gives n keywords and a text, and the program
// prints how many keywords occur in the text.
int main(){
    std::ios::sync_with_stdio(0);
    scanf("%d",&T);
    while(T--){
        scanf("%d",&n);
        aho.init();
        for(int k=0;k<n;++k){
            scanf("%s",s);
            aho.insert(s);
        }
        aho.build();
        scanf("%s",s);
        const int found=aho.match(s);
        printf("%d\n",found);
    }
    return 0;
}
P3808 【模板】AC自动机(简单版)
#include<bits/stdc++.h>
// Contest-template shorthands. This program only uses ms and sc;
// il, pb, SC, SZ, rep and drep are unused here.
#define il inline
#define pb push_back
// ms(_data,v): memset(_data, v, sizeof(_data)) -- fills every byte of an array.
#define ms(_data,v) memset(_data,v,sizeof(_data))
// sc(n) / SC(n,m): read one / two ints with scanf.
#define sc(n) scanf("%d",&n)
#define SC(n,m) scanf("%d %d",&n,&m)
#define SZ(a) int((a).size())
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define drep(i,a,b) for(int i=a;i>=b;--i)
using namespace std;
typedef long long ll;
// inf, pi and eps are not used by this program.
const ll inf=0x3f3f3f3f;
const double pi=acos(-1.0);
const double eps=1e-9;
// Trie node capacity (size of Aho::no).
const int maxn=1e6+5;
// Size of the text/keyword buffer s.
const int maxs=1e6+5;
// Unused helpers; they reference a `mod` constant that this program does not define.
//il int Add(int x,int y) {return x+y>=mod?x+y-mod:x+y;}
//il int Mul(ll x,int y) {return x*y>=mod?x*y%mod:x*y;}
// Aho-Corasick automaton counting how many inserted patterns (duplicates counted
// separately) occur in a text. Node 0 is the root and its fail link is -1.
// A freshly zero-initialized object is ready to use (sz starts at 1); init()
// resets it for another round.
struct Aho {
    struct node {
        int nxt[26];   // trie children ('a'..'z'), 0 = absent
        int fail;      // fail link (-1 for the root)
        int cnt;       // patterns ending here; set to -1 once counted by get()
    } no[maxn];
    int sz=1;          // number of allocated nodes (node 0 is the root)
    queue<int> q;
    // Reset node id to an empty state.
    void clear(int id) {
        ms(no[id].nxt,0);
        no[id].fail=no[id].cnt=0;
    }
    // Start over. Only the root is wiped here; every other node is wiped by insert()
    // when it is allocated, instead of sweeping all maxn nodes.
    void init() {
        while(!q.empty()) q.pop();
        clear(0);
        sz=1;
    }
    // Add pattern s (lowercase letters) to the trie.
    void insert(char *s) {
        int len=strlen(s),root=0,id;
        for(int i=0; i<len; ++i) {
            id=s[i]-'a';
            if(!no[root].nxt[id]) {
                clear(sz); // the slot may still hold data from an earlier round
                no[root].nxt[id]=sz++;
            }
            root=no[root].nxt[id];
        }
        no[root].cnt++;
    }
    // BFS computing fail links: a child's fail is the first node on its parent's
    // fail chain that has an edge with the same letter (that edge's target), or
    // the root if no such node exists.
    void build() {
        no[0].fail=-1;
        q.push(0);
        while(!q.empty()) {
            int u=q.front();
            q.pop();
            for(int i=0; i<26; ++i) {
                if(no[u].nxt[i]) {
                    if(u==0) no[no[u].nxt[i]].fail=0;
                    else {
                        int v=no[u].fail;
                        while(v!=-1) {
                            if(no[v].nxt[i]) {
                                no[no[u].nxt[i]].fail=no[v].nxt[i];
                                break;
                            }
                            v=no[v].fail;
                        }
                        if(v==-1) no[no[u].nxt[i]].fail=0;
                    }
                    q.push(no[u].nxt[i]);
                }
            }
        }
    }
    // Sum the pattern counts along the fail chain starting at `root`, marking each
    // visited node with cnt=-1 so it is counted at most once. The walk stops at the
    // trie root or at an already-marked node; every node on a marked node's fail
    // chain was marked in the same earlier walk, so nothing is lost.
    int get(int root) {
        int res=0;
        while(root>0 && no[root].cnt!=-1) {
            res+=no[root].cnt;
            no[root].cnt=-1;
            root=no[root].fail;
        }
        return res;
    }
    // Run text s (lowercase letters) through the automaton and return the number of
    // patterns found. Each pattern node is counted once.
    int match(char *s) {
        int len=strlen(s),id,res=0,root=0;
        for(int i=0; i<len; ++i) {
            id=s[i]-'a';
            if(no[root].nxt[id]) root=no[root].nxt[id];
            else {
                int p=no[root].fail;
                while(p!=-1 && no[p].nxt[id]==0) p=no[p].fail;
                if(p==-1) root=0;
                else root=no[p].nxt[id];
            }
            // Always walk the fail chain: a pattern can end here as a proper suffix of
            // the current state even when the state itself ends no pattern
            // (e.g. patterns "abc" and "b" in text "abc").
            res+=get(root);
        }
        return res;
    }
} aho;
int T,n;
char s[maxs];
// Single test: read n patterns and a text, then print how many patterns occur.
// aho is a zero-initialized global with sz starting at 1, so init() is not called.
int main() {
    std::ios::sync_with_stdio(0);
    scanf("%d",&n);
    for(int k=0; k<n; ++k) {
        scanf("%s",s);
        aho.insert(s);
    }
    aho.build();
    scanf("%s",s);
    const int found=aho.match(s);
    printf("%d\n",found);
    return 0;
}
后来发现 UP 主的这份代码过不了加强数据,对照数据检查后确实出现了错误答案。
再加上这种逐个结点暴力跳 fail 链的基础版本效率也不高,于是又学习了拓扑排序优化的版本。
P5357 【模板】AC自动机(二次加强版)
#include<bits/stdc++.h>
// Contest-template shorthands. This program only uses il, ms and sc;
// pb, SC, SZ, rep and drep are unused here.
#define il inline
#define pb push_back
// ms(_data,v): memset(_data, v, sizeof(_data)) -- fills every byte of an array.
#define ms(_data,v) memset(_data,v,sizeof(_data))
// sc(n) / SC(n,m): read one / two ints with scanf.
#define sc(n) scanf("%d",&n)
#define SC(n,m) scanf("%d %d",&n,&m)
#define SZ(a) int((a).size())
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define drep(i,a,b) for(int i=a;i>=b;--i)
using namespace std;
typedef long long ll;
// inf, pi and eps are not used by this program.
const ll inf=0x3f3f3f3f;
const double pi=acos(-1.0);
const double eps=1e-9;
// Capacity of the node arrays (trie, in) and of the per-pattern arrays (mp, res).
const int maxn=1e6+5;
// Size of the text/pattern buffer s.
const int maxs=2e6+5;
// Unused helpers; they reference a `mod` constant that this program does not define.
//il int Add(int x,int y) {return x+y>=mod?x+y-mod:x+y;}
//il int Mul(ll x,int y) {return x*y>=mod?x*y%mod:x*y;}
// One automaton state. Node 1 is the root; node 0 is a sentinel used by getFail().
struct node{
    int son[26]; // child / goto transitions for 'a'..'z'
    int fail;    // fail link
    int flag;    // first pattern id ending here (0 = none)
    int ans;     // visits during match, later the aggregated occurrence count
    // Reset this state to empty. Not called by this program: the global trie
    // array is zero-initialized.
    void init(){
        memset(son,0,sizeof(son));
        fail=0;
        flag=0;
        ans=0;
    }
}trie[maxn];
queue<int> q;
// cnt: highest node index in use; in[x]: number of nodes whose fail is x;
// mp[i]: id whose counter answers for pattern i; res[id]: occurrence count of id.
int cnt,in[maxn],mp[maxn],res[maxn];
// Add pattern s (lowercase letters) as pattern number num. Duplicate patterns keep
// the number of their first occurrence, and mp[num] points at that number.
il void insert(char*s ,int num){
    int cur=1;
    for(const char *p=s; *p; ++p){
        int &next=trie[cur].son[*p-'a'];
        if(!next) next=++cnt;
        cur=next;
    }
    int &owner=trie[cur].flag;
    if(!owner) owner=num;
    mp[num]=owner;
}
// Breadth-first construction of fail links. Missing edges are redirected to the
// fail state's transition (turning the trie into a goto automaton), and in[x]
// counts how many nodes have fail == x, for the topological pass in pushup().
il void getFail(){
    fill(trie[0].son,trie[0].son+26,1); // sentinel: every letter leads to the root
    q.push(1);
    while(!q.empty()){
        const int u=q.front();
        const int f=trie[u].fail;
        q.pop();
        for(int c=0;c<26;++c){
            int &child=trie[u].son[c];
            if(child){
                const int link=trie[f].son[c];
                trie[child].fail=link;
                ++in[link];
                q.push(child);
            }else{
                child=trie[f].son[c];
            }
        }
    }
}
// Propagate visit counts up the fail tree in topological order (Kahn's algorithm).
// A node is processed only after every node whose fail link points to it, so its
// ans is final when it is stored into res[flag].
il void pushup(){
    for(int i=1;i<=cnt;++i){
        if(in[i]==0) q.push(i);
    }
    while(!q.empty()){
        const int u=q.front();
        q.pop();
        res[trie[u].flag]=trie[u].ans;
        const int up=trie[u].fail;
        trie[up].ans+=trie[u].ans;
        if(--in[up]==0) q.push(up);
    }
}
// Run text s (lowercase letters) through the automaton, counting how often each
// state is reached, then aggregate the counts with pushup().
il void match(char *s){
    int state=1;
    for(const char *p=s; *p; ++p){
        state=trie[state].son[*p-'a'];
        ++trie[state].ans;
    }
    pushup();
}
int T,n;
char s[maxs];
// Read n patterns and a text, then print each pattern's occurrence count on its own line.
int main() {
    std::ios::sync_with_stdio(0);
    scanf("%d",&n);
    cnt=1; // node 1 is the root; new nodes are numbered from 2
    for(int k=1; k<=n; ++k) {
        scanf("%s",s);
        insert(s,k);
    }
    getFail();
    scanf("%s",s);
    match(s);
    for(int k=1; k<=n; ++k) printf("%d\n",res[mp[k]]);
    return 0;
}