题意:
一个字符串集合,初始为空, m m m 次操作:① 1 , s 1, s 1,s,加入字符串 s s s,保证 s s s 此前未在集合里;② 2 , s 2, s 2,s,删除字符串 s s s;③ 3 , s 3, s 3,s,询问集合中字符串在 s s s 中出现的次数总和。强制在线。 ( m , ∑ ∣ s i ∣ ≤ 3 × 1 0 5 ) (m, \sum |s_i| \leq 3×10^5) (m,∑∣si∣≤3×105)
链接:
https://codeforces.com/problemset/problem/710/F
解题思路:
不是强制在线的话,可以直接离线把模式串全部拿出来建 A C AC AC 自动机,再配合 f a i l fail fail 树处理每个询问。现在强制在线,显然不能操作 ③ 都建立一次 A C AC AC 自动机,所以,可以分块建立 A C AC AC 自动机,对前缀每隔 n \sqrt{n} n 建立一个 A C AC AC 自动机,这样固然可以,但复杂度过高。
这里使用二进制分组,将询问所需的 A C AC AC 自动机的数量降低到 l o g n logn logn,简单说,将前缀 i i i 用二进制表示,每个二进制位建立一个 A C AC AC 自动机, i → i + 1 i \rightarrow i + 1 i→i+1 时,新建一个代表 s i + 1 s_{i + 1} si+1 的自动机代表,然后对 i i i 和这个 1 1 1 合并,比如 i = 9 i = 9 i=9,那么 1001 → 1010 1001 \rightarrow 1010 1001→1010,这样询问的自动机个数就变成了 l o g n logn logn,而每个串代表的自动机每次暴力合并都往高位进 1 1 1,进位次数为 l o g n logn logn 的,也就是每个串至多被重构 l o g n logn logn 次。删除操作可以通过另建一个自动机,询问时作差。( p s : ps: ps: 需要注意 f a i l fail fail 指针的构建时机)
最初的做法是对询问串长分类, 对 ∣ s i ∣ > n |s_i| \gt \sqrt{n} ∣si∣>n 的串,直接把集合里的串拿出来建立 A C AC AC 自动机,这种串至多 n \sqrt{n} n 个,那么这部分复杂度是 O ( n n ∗ Σ ) O(n\sqrt{n} * \Sigma) O(nn∗Σ)。对于 ∣ s i ∣ ≤ n |s_i| \leq \sqrt{n} ∣si∣≤n 的串,由于子串数量是 O ( ∣ s i ∣ 2 ) O(|s_i|^2) O(∣si∣2) 的,那么可以直接哈希统计,但更好的做法可以是先将字符串集合建立字典树,拿 s i s_i si 的 n n n 个后缀去字典树上匹配所有前缀,这样这部分复杂度是 O ( n n ) O(n\sqrt{n}) O(nn) 的。
参考代码:
二进制分组:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define sz(a) ((int)a.size())
#define pb push_back
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)
const int maxn = 3e5 + 5;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
struct AC{
int nxt[maxn][26], fail[maxn], tag[maxn], num[maxn];
int trans[maxn][26], rt[21], cnt;
int add(){
++cnt; tag[cnt] = num[cnt] = fail[cnt] = 0; return cnt;
}
void init(){
cnt = 0;
for(int i = 0; i <= 20; ++i) rt[i] = 0;
}
int merge(int &x, int &y){
if(!x || !y) return x + y;
tag[x] += tag[y];
for(int i = 0; i < 26; ++i){
nxt[x][i] = merge(nxt[x][i], nxt[y][i]);
}
return x;
}
void insert(char *s){
int r = add(), p = r;
while(*s){
int t = *s - 'a';
if(!nxt[p][t]) nxt[p][t] = add();
p = nxt[p][t];
++s;
}
tag[p] = 1;
for(int i = 0; i <= 20; ++i){
if(rt[i]) merge(r, rt[i]), rt[i] = 0;
else { rt[i] = r, cFail(rt[i]); break; }
}
}
void cFail(int rt){
queue<int> q;
for(int i = 0; i < 26; ++i){
int v = nxt[rt][i]; trans[rt][i] = v;
if(v) fail[v] = rt, num[v] = tag[v], q.push(v);
else trans[rt][i] = rt;
}
while(!q.empty()){
int u = q.front(); q.pop();
for(int i = 0; i < 26; ++i){
int v = nxt[u][i]; trans[u][i] = v;
if(v) fail[v] = trans[fail[u]][i], num[v] = tag[v] + num[fail[v]], q.push(v);
else trans[u][i] = trans[fail[u]][i];
}
}
}
int run(int rt, char *s){
int ret = 0;
while(*s){
rt = trans[rt][*s - 'a'];
ret += num[rt];
++s;
}
return ret;
}
int solve(char *s){
int ret = 0;
for(int i = 0; i <= 20; ++i){
if(rt[i]) ret += run(rt[i], s);
}
return ret;
}
} ac1, ac2;
char s[maxn];
int n, m;
int main() {
// ios::sync_with_stdio(0); cin.tie(0);
scanf("%d", &m);
ac1.init(), ac2.init();
while(m--){
int opt; scanf("%d%s", &opt, s);
if(opt == 1){
ac1.insert(s);
}
else if(opt == 2){
ac2.insert(s);
}
else{
int ret = ac1.solve(s) - ac2.solve(s);
printf("%d\n", ret);
fflush(stdout);
}
}
return 0;
}
串长分类:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define sz(a) ((int)a.size())
#define pb push_back
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)
const int maxn = 3e5 + 5;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
struct AC{
int nxt[maxn][26], fail[maxn], num[maxn], cnt;
int add(){
++cnt; memset(nxt[cnt], 0, sizeof nxt[cnt]);
fail[cnt] = num[cnt] = 0; return cnt;
}
void init(){
cnt = -1; add();
}
void insert(char *s, int flg){
int p = 0;
while(*s){
int t = *s - 'a';
if(!nxt[p][t]) nxt[p][t] = add();
p = nxt[p][t];
++s;
}
num[p] += flg;
}
void cFail(){
queue<int> q;
for(int i = 0; i < 26; ++i) if(int v = nxt[0][i]) q.push(v);
while(!q.empty()){
int u = q.front(); q.pop();
for(int i = 0; i < 26; ++i){
if(int v = nxt[u][i]) fail[v] = nxt[fail[u]][i], num[v] += num[fail[v]], q.push(v);
else nxt[u][i] = nxt[fail[u]][i];
}
}
}
ll run(char *s, int flg){
ll ret = 0; int p = 0;
while(*s){
p = nxt[p][*s - 'a'];
ret += num[p];
++s;
if(flg && !p) break;
}
return ret;
}
} trie, ac;
char s[maxn];
int n, m;
int main() {
// ios::sync_with_stdio(0); cin.tie(0);
scanf("%d", &m);
trie.init();
while(m--){
int opt; scanf("%d%s", &opt, s);
if(opt == 1){
trie.insert(s, 1);
}
else if(opt == 2){
trie.insert(s, -1);
}
else{
ll ret = 0;
n = strlen(s);
if(n <= 600){
for(int i = 0; i < n; ++i){
ret += trie.run(s + i, 1);
}
}
else{
ac = trie;
ac.cFail();
ret = ac.run(s, 0);
}
printf("%d\n", ret);
fflush(stdout);
}
}
return 0;
}