题意
给出n个单词组成的字典,给出m个文本,对于每个文本求出字典中出现的单词个数,保证每个文本最多出现3个字典中的单词。
题解
AC自动机模板题,这道题卡内存。
因为是单组数据,没有必要初始化。我的代码无论用什么初始化,都MLE。
下面的代码过了,还有优化的余地,ans可以不用开数组,直接一个vector就可以了。还可以手写队列。
代码
#include <bits/stdc++.h>
using namespace std;
const int nmax = 1e5 + 7;
const int sigma = 128;
vector<int> ans[1005];
struct Aho {
int sz;
queue<int> que;
struct Node {
int nxt[sigma];
int cnt, fail, id;
} node[nmax];
void init() {
// while(!que.empty()) que.pop();
// for (int i = 0; i < nmax; ++i) {
// for (int j = 0; j < sigma; ++j)
// node[i].nxt[j] = 0;
// node[i].cnt = node[i].fail = node[i].id = 0;
// }
// memset(node, 0, sizeof node);
sz = 1;
}
inline int idx(const char & ch) {return ch;}
void insert(char * S, int id) {
int len = strlen(S);
int now = 0;
for(int i = 0; i < len; ++i) {
char ch = S[i];
if(!node[now].nxt[idx(ch)])
node[now].nxt[idx(ch)] = sz++;
now = node[now].nxt[idx(ch)];
}
node[now].cnt ++;
node[now].id = id;
}
void build_fail() {
node[0].fail = -1;
que.push(0);
while(!que.empty()) {
int u = que.front(); que.pop();
for(int i = 0; i < sigma; ++i) {
if(node[u].nxt[i]) {
if(u == 0) node[node[u].nxt[i]].fail = 0;
else {
int v = node[u].fail;
while(v != -1) {
if(node[v].nxt[i]) {
node[node[u].nxt[i]].fail = node[v].nxt[i];
break;
}
v = node[v].fail;
}
if(v == -1) node[node[u].nxt[i]].fail = 0;
}
que.push(node[u].nxt[i]);
}
}
}
}
void Get(int u, int ask_id) {
while(u) {
if(node[u].cnt) ans[ask_id].push_back(node[u].id);
u = node[u].fail;
}
}
void match(char * S, int ask_id) {
int len = strlen(S);
int now = 0;
for(int i = 0; i < len; ++i) {
char ch = S[i];
if(node[now].nxt[idx(ch)])
now = node[now].nxt[idx(ch)];
else {
int fa = node[now].fail;
while(fa != -1 && !node[fa].nxt[idx(ch)]) fa = node[fa].fail;
if(fa == -1) now = 0;
else now = node[fa].nxt[idx(ch)];
}
if(node[now].cnt)
Get(now, ask_id);
}
}
} aho;
int n,m,cnt;
char str[nmax];
int main() {
aho.init();
scanf("%d", &n);
for(int i = 0; i < n; ++i) {
scanf("%s", str);
aho.insert(str, i+1);
}
aho.build_fail();
scanf("%d", &m);
for(int i = 0; i < m; ++i) {
scanf("%s", str);
aho.match(str, i);
if(ans[i].size()) {
cnt ++;
printf("web %d: ", i+1);
sort(ans[i].begin(),ans[i].end());
for(int j = 0; j < ans[i].size(); ++j) {
if(j == 0) printf("%d", ans[i][j]);
else printf(" %d", ans[i][j]);
}
printf("\n");
}
}
printf("total: %d\n", cnt);
}