本蒟蒻来介绍一下三种方法吧
Solution 1(暴力&哈希)
第一种方法比较暴力,它的关键点在于:
所有字符串的总长不超过 m = 3 × 1 0 5 m=3\times10^5 m=3×105,不难想到最多有不超过 2 m 2\sqrt m 2m 种不同的长度。大概是不超过 800 的。所以遍历子串时只需要遍历 O ( m ) \operatorname O(\sqrt m) O(m) 种长度的子串就行了。
根据这一结论,我们可以暴力地维护 O ( m ) \operatorname O(\sqrt m) O(m) 个 set,里面存所有对应长度字符串的哈希值。
- 插入&删除:对该长度的 set 进行相应操作。 O ( log m ) \operatorname O(\log m) O(logm).
- 询问:暴力的遍历文本串 s s s 所有出现过的长度的子串,用滚动哈希每次去对应 set 里查找哈希值。 O ( m log m ) \operatorname O(\sqrt m \log m) O(mlogm).
这样一来就能暴力的解决。注意一下哈希的细节,最好开双哈希避免冲突,因为这题数据很强(真实体会)。
-
时间复杂度: O ( m m log m ) \operatorname O(m\sqrt m \log m) O(mmlogm)
-
空间复杂度: O ( m ) \operatorname O(m) O(m)
Code
#include <bits/stdc++.h>
#define ll long long
#define pb push_back
#define pii pair<int,int>
#define mp make_pair
#define F first
#define S second
using namespace std;
const ll p1=37,M1=998244353,p2=97,M2=1610612741;//双哈希常数&模数
int q,id[300005],sz[805],num;
ll pw1[300005],pw2[300005];
set<ll> se[805];//存不同长度哈希值的set
int main()
{
ios::sync_with_stdio(false),cin.tie(nullptr);
cin>>q;
memset(id,-1,sizeof(id));
pw1[0]=pw2[0]=1ll;
for(int i=1;i<=300000;i++)//预处理哈希常数的幂
{
pw1[i]=pw1[i-1]*p1%M1;
pw2[i]=pw2[i-1]*p2%M2;
}
while(q--)
{
int tp;string s;
cin>>tp>>s;
int len=s.size();
if (tp==1)
{
ll H1=0ll,H2=0ll;
for(int i=0;i<len;i++)//求双哈希值
{
H1=(H1*p1+(ll)(s[i]-'a'))%M1;
H2=(H2*p2+(ll)(s[i]-'a'))%M2;
}
if (id[len]==-1)
{
sz[num]=len;
id[len]=num++;
}
se[id[len]].insert(H1<<31|H2);//插入
}
else if (tp==2)
{
ll H1=0ll,H2=0ll;
for(int i=0;i<len;i++)
{
H1=(H1*p1+(ll)(s[i]-'a'))%M1;
H2=(H2*p2+(ll)(s[i]-'a'))%M2;
}
se[id[len]].erase(H1<<31|H2);//删除
}
else
{
int n=s.size(),res=0;
for(int i=0;i<num;i++)//枚举长度
{
int len=sz[i];
if (len>n)
continue;
ll H1=0ll,H2=0ll;
for(int j=0;j<len;j++)
{
H1=(H1*p1+(ll)(s[j]-'a'))%M1;
H2=(H2*p2+(ll)(s[j]-'a'))%M2;
}
res+=se[i].count(H1<<31|H2);
for(int j=len;j<n;j++)//滚动哈希遍历
{
H1=(H1*p1+(ll)(s[j]-'a')+M1-pw1[len]*(ll)(s[j-len]-'a')%M1)%M1;
H2=(H2*p2+(ll)(s[j]-'a')+M2-pw2[len]*(ll)(s[j-len]-'a')%M2)%M2;
res+=se[i].count(H1<<31|H2);
}
}
cout<<res<<endl;cout.flush();
}
}
return 0;
}
Solution 2 (暴力&trie 树&KMP)
这个方法不是我想的,是另一位巨佬写出来的。当时看的我大为震撼。
他的暴力思路是:将较小的模式串都插入到 trie 树中,较大的模式串单独扔在一遍用 KMP 匹配。
每次查询时,对文本串的所有后缀扔到 trie 里查询子串匹配个数,求总和。再把文本串与所有较大模式串单独 KMP。
看到这个方法能过的时候,我甚至开始思考 AC 自动机解法有什么存在的必要。
Code
#include <bits/stdc++.h>
using namespace std;
const int maxn = 300010;
struct Trie{//trie树的构体
public:
static const int maxn = 300010, maxc = 26;
int child[maxn][maxc], value[maxn];
int cnttrie, root;
int newnode(){
cnttrie++; value[cnttrie] = 0;
memset(child[cnttrie], -1, sizeof(child[cnttrie]));
return cnttrie;
}
public:
void init(){
cnttrie = root = 0; value[root] = 0;
memset(child[root], -1, sizeof(child[root]));
}
void insert(char *s,int val){
int x = root;
for (int i = 0; s[i]; i++){
int d = s[i] - 'a';
if (child[x][d] == -1){
child[x][d] = newnode();
}
x = child[x][d];
}
value[x] += val;
}
int search(char *s){
int sum = 0, x = root;
for (int i = 0; s[i]; i++){
int d = s[i] - 'a';
if (child[x][d] == -1){
break;
}
x = child[x][d];
sum += value[x];
}
return sum;
}
}trie;
class KMP{//KMP匹配的板子
private:
static const int maxn = 300010;
int next[maxn];
void getnext(string s, int len){
memset(next, 0, sizeof(next));
int i = 0, j = -1; next[0] = -1;
while (i < len){
if (j == -1 || s[i] == s[j]){
next[++i] = ++j;
}else{
j = next[j];
}
}
}
public:
int search(string s, char *str){
int ret = 0, j = 0;
int lens = s.size(), lenstr = strlen(str);
getnext(s, s.size());
for (int i = 0; i < lenstr; i++){
while (j > 0 && s[j] != str[i]){
j = next[j];
}
if (s[j] == str[i]){
j++;
}
if (j == lens){
ret++; j = next[j];
}
}
return ret;
}
}kmp;
string lib[50];
char str[maxn];
int cnt, cntx[50];
int main(){
int T = 1; //cin >> T;
int t, m;
int len, val;
while (T--){
cin >> m;
trie.init();
cnt = 0; memset(cntx, 0, sizeof(cntx));
for (int i = 1; i <= m; i++){
scanf("%d%s", &t, str);
len = strlen(str);
if (t == 1 || t == 2){
val = (t == 1 ? 1 : -1);//字符串权值,插入为1,删除为-1
if (len <= 1000){//模式串长度<=1000,插入trie
trie.insert(str, val);
}else{//长度>1000,单独扔到一边
lib[cnt] = string(str);
cntx[cnt++] = val;
}
}else{
long long ans = 0;
for (int i = 0; i < len; i++){//对所有后缀在trie上查找
ans += trie.search(str + i);
}
for (int i = 0; i < cnt; i++){//对所有单独的模式串KMP
if (lib[i].length() > len){ continue; }
ans += kmp.search(lib[i], str)*cntx[i];
}
printf("%lld\n", ans);
fflush(stdout);
}
}
}
return 0;
}
-
关于时间复杂度,大概是因为这里以 1000 为标准线分类的设计比较巧妙,能很好的控制 trie 上操作和 KMP 的复杂度都不会很高。
-
空间复杂度: O ( 26 m ) \operatorname O(26m) O(26m)
Solution 3 (二进制分组 &AC 自动机)
这个解法是最常规的,其他题解介绍的也差不多了。大致思想就是因为 AC 自动机比较难支持在线的修改查询,那样每次都要重新构造 fail 指针,复杂度会很大。所以考虑维护不超过 log m \log m logm 个 AC 自动机。第 i i i 个自动机储存 2 i 2^i 2i 个模式串,如下:
8 , 4 , 1 8,4,1 8,4,1
插入一个新模式串,就新建一组 AC 自动机。
8 , 4 , 1 , 1 8,4,1,1 8,4,1,1
然后只要前一个自动机大小与当前新自动机相等,就不断向前暴力合并。
8 , 4 , 2 8,4,2 8,4,2
如果再插入:
8 , 4 , 2 , 1 8,4,2,1 8,4,2,1
8 , 4 , 2 , 1 , 1 → 8 , 4 , 2 , 2 → 8 , 4 , 4 → 8 , 8 → 16 8,4,2,1,1\rightarrow 8,4,2,2\rightarrow 8,4,4\rightarrow 8,8 \rightarrow 16 8,4,2,1,1→8,4,2,2→8,4,4→8,8→16
至于合并的方式,对每个 AC 自动机维护一个 vector 容器存里面所有的模式串,然后只要前面可以合并,就将其中所有模式串拿出来向前合并即可。每次合并结束只需要对最后一个被合并的 AC 自动机重新构建 fail 指针。
删除操作:可以视为插入了一个权值为 -1 的字符串,同样用新建自动机向前合并的方式。
查询操作:遍历所有 AC 自动机,进行匹配,答案为匹配个数总和。
可以看出一个模式串最多被向前合并 log m \log m logm 次,且最多存在 log m \log m logm 组 AC 自动机。
-
时间复杂度: O ( m log m ) \operatorname O(m\log m) O(mlogm)
-
空间复杂度: O ( 26 m ) O(26m) O(26m)
一些细节 (血的教训):
- 如果每个 AC 自动机都是用数组的形式,那么空间复杂度是 O ( 26 m log m ) \operatorname O(26m\log m) O(26mlogm),是会 MLE 的。所以这里要用动态开点的 AC 自动机,用 vector 存储数据。
- 合并的过程要注意保证每个字符串只被拷贝一次,否则复杂度没法控制。
- 因为本题数据较强,AC 自动机部分的代码一定要控制常数。
- AC 自动机查询的函数中,不能暴力跳 fail 指针,负责会被第 27 组数据卡掉。要注意在构建 fail 指针的过程中,预处理一个点以及它 fail 指针后面的所有点权值总和,这样能避免跳 fail,保证每个点查询过程中只被遍历一次。
Code
#include <bits/stdc++.h>
#define ll long long
#define pb push_back
#define pii pair<int,int>
#define mp make_pair
#define F first
#define S second
using namespace std;
int m,pr;
struct AC//AC自动机结构体
{
int trienum;
vector<int> cnt,tot,son[30],fail;//动态开点数组
vector<pair<string,int> > v; //储存内部所有字符串及其权值的 vector
void init()//初始化
{
v.clear();
cnt.clear(),fail.clear(),tot.clear();
trienum=0;
cnt.pb(0),fail.pb(0),tot.pb(0);
for(int i=0;i<26;i++)
{
son[i].clear();
son[i].pb(0);
}
}
void ins(string t,int x)//插入
{
v.pb(mp(t,x));
int p=0;
for(int i=0;i<(int)t.size();i++)
{
if (son[t[i]-'a'][p]<=0)
{
for(int j=0;j<26;j++)
son[j].pb(0);
cnt.pb(0),fail.pb(0),tot.pb(0);
son[t[i]-'a'][p]=++trienum;
}
p=son[t[i]-'a'][p];
}
cnt[p]+=x;
}
void build()//bfs构建fail指针
{
tot.clear();
tot.resize(trienum+1,0);
queue<int> q;
for(int i=0;i<26;i++)
if (son[i][0]>0)
{
q.push(son[i][0]);
tot[son[i][0]]=cnt[son[i][0]];
}
while(!q.empty())
{
int u=q.front();
q.pop();
for(int i=0;i<26;i++)
if (son[i][u]>0)
{
fail[son[i][u]]=abs(son[i][fail[u]]);
tot[son[i][u]]=cnt[son[i][u]]+tot[fail[son[i][u]]];//cnt是单点权值,tot是这个点和fail指针之后的所有点权值总和
q.push(son[i][u]);
}
else//这一步是构造trie图,能方便失配时一步到达fail指针的位置,但是这里用负数标记,表示实际上这个儿子是不存在的
son[i][u]=-abs(son[i][fail[u]]);
}
}
ll query(string t)//查询
{
int u=0;
ll res=0ll;
for(int i=0;i<(int)t.size();i++)
{
u=abs(son[t[i]-'a'][u]);
res+=tot[u];//避免了跳fail的过程
}
return res;
}
}ac[25];
int main()
{
int tp;
char s[300005];
scanf("%d",&m);
while(m--)
{
scanf("%d%s",&tp,s);
if (tp<3)
{
int val=(tp==1?1:-1);//字符串权值,插入为1,删除为-1
int cur=0,lst=pr;
ac[pr].init();
ac[pr].ins(s,val);
while(pr&&(int)ac[pr-1].v.size()==(int)ac[pr].v.size()+cur)//向前合并
cur+=(int)ac[pr--].v.size();
for(int i=pr+1;i<=lst;i++)
for(auto p:ac[i].v)
ac[pr].ins(p.F,p.S);
ac[pr++].build();
}
else
{
ll res=0ll;
for(int i=0;i<pr;i++)
res+=ac[i].query(s);
printf("%lld\n",res);
fflush(stdout);
}
}
return 0;
}