ACwing 1285. 单词
原理是任一一个字串,可以对原字符串的前缀的后缀取得!
题解
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 1e6 + 10;
int pos[MAX_N]; //第i个字符串在tire结束的节点位置
//从下标1开始
int N;
namespace AC {
int tot, tr[MAX_N][26];//tire树
int fail[MAX_N];//不管
int q[MAX_N];//队列
int cnt[MAX_N];//个数
void init() {
memset(fail, 0, sizeof(fail));
memset(tr, 0, sizeof(tr));
memset(cnt, 0, sizeof(cnt));
memset(q, 0, sizeof(q));
tot = 0;
}
void insert(char *s, int id) {
int u = 0;
for(int i = 1; s[i]; i++) {
int t = s[i] - 'a';
if(!tr[u][t]) tr[u][t] = ++tot;
u = tr[u][t];
cnt[u]++;
}
pos[id] = u;
}
void build() {
int hh = 0, tt = -1;
for (int i = 0; i < 26; i++)
if (tr[0][i]) q[++tt] = tr[0][i];
while (hh <= tt) {
int t = q[hh++];
for (int i = 0; i < 26; i++) {
int j = tr[t][i];
if (!j) tr[t][i] = tr[fail[t]][i];
else {
fail[j] = tr[fail[t]][i];
q[++tt] = j;
}
}
}
}
void build_egde(){//反向建边
for(int i = tot - 1; i >= 0; i--) {
cnt[fail[q[i]]] += cnt[q[i]];
}
}
}
const int MAX = 1e6 + 10;
char ch[MAX];
int main() {
int n;
scanf("%d", &n);
AC::init();
for(int i = 1; i <= n; i++) {
scanf("%s", ch + 1);
AC::insert(ch, i);//插入节点位置,就不用用char数组存了
//1e2 * 1e6 会爆
//除非写string数组,然后strcpy(ch + 1, s[i].c_str());
}
AC::build();
AC::build_egde();
for(int i = 1; i <= n; i++) {
printf("%d\n", AC::cnt[pos[i]]);
}
}
再分享一个写法
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 5e5 + 10;
int pos[MAX_N]; //第i个字符串在tire结束的节点位置
//从下标1开始
int N;
namespace AC {
int tot, tr[MAX_N][26];//tire树
int fail[MAX_N];//不管
int idx[MAX_N];//节点编号
int q[MAX_N];//队列
int cnt[MAX_N];//个数
void init() {
memset(fail, 0, sizeof(fail));
memset(tr, 0, sizeof(tr));
memset(idx, 0, sizeof(idx));
tot = 0;
}
void insert(char *s, int id) {
int u = 0;
for (int i = 1; s[i]; i++) {
if (!tr[u][s[i] - 'a']) tr[u][s[i] - 'a'] = ++tot;
u = tr[u][s[i] - 'a'];
cnt[u]++;
}
idx[u]++;
pos[id] = u;
}
void build() {
int hh = 0, tt = -1;
for (int i = 0; i < 26; i++)
if (tr[0][i]) q[++tt] = tr[0][i];
while (hh <= tt) {
int t = q[hh++];
for (int i = 0; i < 26; i++) {
int j = tr[t][i];
if (!j) tr[t][i] = tr[fail[t]][i];
else {
fail[j] = tr[fail[t]][i];
q[++tt] = j;
}
}
}
}
}
char ch[MAX_N];
vector<int> g[MAX_N];
int val[MAX_N];
void dfs(int u) {
val[u] = AC::cnt[u];
for (auto x:g[u]) {
dfs(x);
val[u] += val[x];
}
}
int main() {
int n;
memset(AC::q, 0, sizeof AC::q);
scanf("%d", &n);
AC::init();
for (int i = 0; i < n; i++) {
scanf("%s", ch + 1);
AC::insert(ch, i);//插入节点位置,就不用用char数组存了
//1e2 * 1e6 会爆
//除非写string数组,然后strcpy(ch + 1, s[i].c_str());
}
AC::build();
for (int i = 1; i <= AC::tot; i++) {
g[AC::fail[i]].push_back(i);
}
dfs(0);
for (int i = 0; i < n; i++) {
printf("%d\n", val[pos[i]]);
}
}