题意:给一个长度为 n n n 的串 S S S 和一个长度为 b b b 的串 B B B,有 m m m 个文本串,初始它们都是空串。需要支持 q q q 个操作,每个操作要么是给某个文本串后面接上串 B [ l , r ] B[l,r] B[l,r],要么是询问某个文本串在 S S S 中的出现次数。
题解:
一开始的想法是后缀自动机,但 “给一个文本串接上 B [ l , r ] B[l,r] B[l,r]” 相当于在自动机上走 r − l + 1 r-l+1 r−l+1 步,这个可能比较复杂。
考虑使用后缀数组,一个文本串 T i T_i Ti 在 S S S 中的出现次数相当于这个文本串作为 S S S 的多少个后缀的前缀出现,这些前缀包含 T i T_i Ti 的后缀在后缀数组中肯定是一个连续的区间,称为串 T i T_i Ti 对应的后缀区间。我们考虑动态维护每个文本串 T i T_i Ti 的后缀区间 [ L i , R i ] [L_i,R_i] [Li,Ri]。
现在的问题是如何合并两个串,即我们已经知道了串 A A A 对应的后缀区间 [ L 1 , R 1 ] [L_1,R_1] [L1,R1] 和串 B B B 对应的后缀区间 [ L 2 , R 2 ] [L_2,R_2] [L2,R2],如何快速求出串 A + B A+B A+B 对应的后缀区间 [ L 3 , R 3 ] [L_3,R_3] [L3,R3]。因为知道了这个之后, B [ l , r ] B[l,r] B[l,r] 对应的后缀区间就能用线段树求出,而在文本串末尾接上 B [ l , r ] B[l,r] B[l,r] 之后对应的后缀区间也能快速求出。
可以暴力使用二分+哈希check,能够实现 O ( log 2 n ) O(\log ^2n) O(log2n) 的合并。但事实上 check 可以更快:注意到 [ L 3 , R 3 ] [L_3,R_3] [L3,R3] 一定是 [ L 1 , R 1 ] [L_1,R_1] [L1,R1] 的子区间,所以我们在 [ L 1 , R 1 ] [L_1,R_1] [L1,R1] 内先直接二分 L 3 L_3 L3。现在要判断 S S S 的一个后缀 S [ s a m i d , n ] S[sa_{mid},n] S[samid,n] 是否小于串 A + B A+B A+B,这里使用 Hash 判断就可以做到刚刚说的 O ( log n ) O(\log n) O(logn) check,但事实上,由于 m i d ∈ [ L 1 , R 1 ] mid\in [L_1,R_1] mid∈[L1,R1], S [ s a m i d , n ] S[sa_{mid},n] S[samid,n] 肯定包含前缀 A A A,所以我们要比较的实质是 S [ s a m i d + ∣ A ∣ , n ] S[sa_{mid}+|A|,n] S[samid+∣A∣,n] 和串 B B B,这可以通过比较 s a m i d + ∣ A ∣ sa_{mid}+|A| samid+∣A∣ 和 [ L 2 , R 2 ] [L_2,R_2] [L2,R2] 做到 O ( 1 ) O(1) O(1) 比较。于是 check 变为 O ( 1 ) O(1) O(1),二分 L 3 L_3 L3 就变为 O ( log n ) O(\log n) O(logn) 的,然后再二分 R 3 R_3 R3 即可,合并 A , B A,B A,B 的总时间复杂度就变为 O ( log n ) O(\log n) O(logn)。
事实上,我们可以把 S S S 和 B B B 中间隔一个特殊字符拼在一起得到一个长串 S ′ S' S′,然后在这个串上做后缀数组。这样 B [ l , r ] B[l,r] B[l,r] 对应的后缀区间就可以不用线段树,而是直接二分得到(因为 B [ l , b ] B[l,b] B[l,b] 已经是 S ′ S' S′ 的一个包含前缀 B [ l , r ] B[l,r] B[l,r] 的后缀了)。注意此时询问时要输出的是后缀区间 [ l , r ] [l,r] [l,r] 中起始位置 ≤ n \leq n ≤n 的后缀个数。
总时间复杂度降为 O ( n + b + q log ( n + b ) ) O(n+b+q\log (n+b)) O(n+b+qlog(n+b))。
#include<bits/stdc++.h>
#define N 500010
#define fi first
#define se second
#define pii pair<int,int>
#define mk(a,b) make_pair(a,b)
using namespace std;
namespace modular
{
const int mod=998244353;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
}using namespace modular;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
const int base=19260817;
struct Query
{
int opt,x;
}q[N];
int ns,m,Q,n;
int poww[N<<1],sum[N<<1];
int sa[N<<1],rk[N<<1],inS[N<<1];
char s[N<<1];
namespace SA
{
const int NN=N<<2;
int node=1,ch[NN][27],len[NN],fa[NN],lefpos[NN],st[NN];
namespace Tree
{
vector<pii>e[NN];
void dfs1(int u)
{
for(pii &now:e[u])
{
int v=now.se;
dfs1(v);
now.fi=s[lefpos[v]+len[u]]-'a';
if(!lefpos[u]) lefpos[u]=lefpos[v];
}
}
int tot;
void dfs2(int u)
{
sort(e[u].begin(),e[u].end());
if(st[u]) sa[++tot]=st[u],rk[st[u]]=tot;
for(pii now:e[u])
{
int v=now.se;
dfs2(v);
}
}
void work()
{
for(int i=2;i<=node;i++)
e[fa[i]].push_back(mk(0,i));
dfs1(1);
dfs2(1);
}
}
namespace SAM
{
int last=1;
void insert(int c,int nid)
{
int p=last,now=last=++node;
len[now]=len[p]+1;
lefpos[now]=st[now]=nid;
for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=now;
if(!p) fa[now]=1;
else
{
int q=ch[p][c];
if(len[p]+1==len[q]) fa[now]=q;
else
{
int nq=++node;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
len[nq]=len[p]+1;
fa[nq]=fa[q],fa[q]=nq,fa[now]=nq;
for(;p&&ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
}
}
}
void work()
{
for(int i=n;i>=1;i--)
insert(s[i]-'a',i);
}
}
void build()
{
SAM::work();
Tree::work();
}
}
struct Hash
{
int val,len;
Hash(){};
Hash(int a,int b){val=a,len=b;}
};
inline Hash operator + (Hash a,Hash b)
{
return Hash(add(mul(a.val,poww[b.len]),b.val),a.len+b.len);
}
inline int query(int l,int r)
{
return dec(sum[r],mul(sum[l-1],poww[r-l+1]));
}
struct data
{
Hash h;
int l,r;
}p[N];
inline int getsmall(int l,int r,int lenA,int Bl)
{
int ans=l-1;
while(l<=r)
{
int mid=(l+r)>>1;
int p=sa[mid]+lenA;
if(rk[p]<Bl) ans=mid,l=mid+1;
else r=mid-1;
}
return ans;
}
inline int getsamel(int l,int r,Hash h)
{
int ans=r+1;
while(l<=r)
{
int mid=(l+r)>>1;
if(sa[mid]+h.len-1<=n&&h.val==query(sa[mid],sa[mid]+h.len-1)) ans=mid,r=mid-1;
else l=mid+1;
}
return ans;
}
inline int getsamer(int l,int r,Hash h)
{
int ans=l-1;
while(l<=r)
{
int mid=(l+r)>>1;
if(sa[mid]+h.len-1<=n&&h.val==query(sa[mid],sa[mid]+h.len-1)) ans=mid,l=mid+1;
else r=mid-1;
}
return ans;
}
inline data merge(data a,data b)
{
data c;
c.h=a.h+b.h;
c.l=getsmall(a.l,a.r,a.h.len,b.l)+1;
c.r=getsamer(c.l,a.r,c.h);
return c;
}
inline data getT(int tl,int tr)
{
tl+=ns+1,tr+=ns+1;
const int p=rk[tl];
data res;
res.h=Hash(query(tl,tr),tr-tl+1);
res.l=getsamel(0,p,res.h);
res.r=getsamer(p,n,res.h);
return res;
}
int in[N];
int main()
{
// freopen("ex_2.in","r",stdin);
// freopen("ex_2.out","w",stdout);
read(),ns=n=read(),m=read(),Q=read();
scanf("%s",s+1);
s[++n]='z'+1;
for(int i=1;i<=Q;i++)
{
q[i].opt=read();
if(q[i].opt==2)
{
char str[2];
scanf("%s",str);
s[++n]=str[0];
}
else q[i].x=read();
}
poww[0]=1;
for(int i=1;i<=n;i++)
{
poww[i]=mul(poww[i-1],base);
sum[i]=add(mul(sum[i-1],base),s[i]-'a'+1);
}
SA::build();
for(int i=1;i<=n;i++)
inS[i]=inS[i-1]+(sa[i]<=ns);
for(int i=1;i<=m;i++)
p[i].h=Hash(0,0),p[i].l=0,p[i].r=n;
int ntim=0;
for(int i=1;i<=Q;i++)
{
if(q[i].opt==2)
{
ntim++;
continue;
}
if(q[i].opt==1)
{
if(in[q[i].x])
{
p[q[i].x]=merge(p[q[i].x],getT(in[q[i].x],ntim));
in[q[i].x]=0;
}
else in[q[i].x]=ntim+1;
}
if(q[i].opt==3)
{
data now=p[q[i].x];
if(in[q[i].x])
now=merge(now,getT(in[q[i].x],ntim));
if(now.l) printf("%d\n",inS[now.r]-inS[now.l-1]);
else puts("0");
}
}
return 0;
}