如果只有查询,那么建一个后缀自动机,建完后扫一遍求每个点子树里的叶节点个数(right集合大小)。
然后查询就从根开始走trs指针。
不过这题还有在串后补字符的操作,加字符时维护子树叶节点个数如果直接往上修改就变成
O(n2)
的了。(听说加了一组数据。。。)
由于每次修改都是单点或是一条链,因此可以用LCT维护父亲的关系和一个点子树中叶子个数。
#include <bits/stdc++.h>
using namespace std;
#define N 3100000
#define M 1210000
#define which(x) (ch[fa[x]][1]==x)
#define ls(x) ch[x][0]
#define rs(x) ch[x][1]
int q,mask;
char s[N],s1[11];
struct LCT
{
int ch[M][2],fa[M],sum[M],rev[M],bj[M];
int isroot(int x)
{return !fa[x]||ch[fa[x]][which(x)]!=x;}
void pushdown(int x)
{
if(rev[x])
{
swap(ls(x),rs(x));
rev[ls(x)]^=1;rev[rs(x)]^=1;
rev[x]=0;
}
if(bj[x])
{
bj[ls(x)]+=bj[x];sum[ls(x)]+=bj[x];
bj[rs(x)]+=bj[x];sum[rs(x)]+=bj[x];
bj[x]=0;
}
}
void down(int x)
{
if(!isroot(x))down(fa[x]);
pushdown(x);
}
void rotate(int x)
{
int y=fa[x],k=which(x);
ch[y][k]=ch[x][k^1];
ch[x][k^1]=y;
if(!isroot(y))ch[fa[y]][which(y)]=x;
fa[x]=fa[y];fa[y]=x;
fa[ch[y][k]]=y;
}
void splay(int x)
{
down(x);
while(!isroot(x))
{
int y=fa[x];
if(isroot(y))rotate(x);
else
{
if(which(x)^which(y))rotate(x);
else rotate(y);
rotate(x);
}
}
}
void access(int x)
{
int t=0;
while(x)
{
splay(x);
ch[x][1]=t;
t=x;x=fa[x];
}
}
void rever(int x)
{
access(x);splay(x);
rev[x]^=1;
}
void link(int x,int y)
{
rever(x);
fa[x]=y;
}
void cut(int x,int y)
{
rever(x);
access(y);splay(y);
fa[x]=ch[y][0]=0;
}
void add(int x)
{
rever(1);
access(x);splay(x);
bj[x]++;sum[x]++;
}
int get(int x)
{
access(x);splay(x);
return sum[x];
}
}lct;
struct SAM
{
int trs[M][27],len[M],fa[M];
int last,cnt;
void init(){last=cnt=1;}
void fat(int x,int y)
{
if(fa[x])lct.cut(x,fa[x]);
fa[x]=y;
lct.link(x,y);
}
void insert(int x)
{
int p=last,np=++cnt,q,nq;
last=np;len[np]=len[p]+1;
for(;p&&!trs[p][x];p=fa[p])trs[p][x]=np;
if(!p)fat(np,1);
else
{
q=trs[p][x];
if(len[q]==len[p]+1)fat(np,q);
else
{
fat(nq=++cnt,fa[q]);
len[nq]=len[p]+1;
lct.sum[nq]+=lct.get(q);
memcpy(trs[nq],trs[q],sizeof(trs[q]));
fat(q,nq);fat(np,nq);
for(;p&&trs[p][x]==q;p=fa[p])trs[p][x]=nq;
}
}
lct.add(np);
}
void insert(char *s)
{
int len=strlen(s+1);
for(int i=1;i<=len;i++)
insert(s[i]-'A'+1);
}
int query(char *s)
{
int len=strlen(s+1),now=1;
for(int i=1;i<=len;i++)
now=trs[now][s[i]-'A'+1];
if(!now)return 0;
return lct.get(now);
}
void print()
{
puts("SUM:");
for(int i=1;i<=cnt;i++)
printf("%d %d\n",i,fa[i]);
puts("");
}
}sam;
void decode(char *s,int mask)
{
int len=strlen(s);
for(int i=0;i<len;i++)
{
mask=(mask*131+i)%len;
swap(s[i],s[mask]);
}
}
void print()
{
puts("LCT:");
for(int i=1;i<=sam.cnt;i++)
{
if(lct.fa[i])
printf("%d %d %d\n",i,lct.fa[i],!lct.isroot(i));
}
}
int main()
{
sam.init();
scanf("%d%s",&q,s+1);
sam.insert(s);
while(q--)
{
scanf("%s%s",s1+1,s+1);
decode(s+1,mask);
if(s1[1]=='A')
sam.insert(s);
else
{
int t=sam.query(s);
mask^=t;
printf("%d\n",t);
}
}
return 0;
}