题目大意:初始字串为空,首先给定一系列操作序列,有三种操作:
1.在结尾加一个字符
2.在结尾删除一个字符
3.打印当前字串
然后多次询问第x个打印的字串在第y个打印的字串中出现了几次
思路:
构造出fail树,dfs序标号,查询按y排序,树状数组单点修改区间查询
#include<cstring>
#include<algorithm>
#include<cstring>
#include<vector>
#include<cstdio>
using namespace std;
const int N=100100;
char s[N];
int q,ans[N],L[N],R[N],id=0;
struct node{
int x,y,id;
}a[N];
struct Edge{
int to,next;
}e[N];
int tot,head[N];
void init(){
tot=0;
memset(head,-1,sizeof(head));
}
void addedge(int from,int to){
e[tot]=(Edge){to,head[from]};
head[from]=tot++;
}
struct Trie{
int sz,val[N],next[N][26],fail[N],root,C[N],Q[N],ed[N];
int S[N];
int newnode(){
memset(next[sz],-1,sizeof(next[sz]));
val[sz]=0,sz++;
return sz-1;
}
void init(){
sz=0;
root=newnode();
}
void insert(){
int now=root,len=strlen(s),top=0,cnt=0;
for(int i=0;i<len;i++){
if(s[i]=='B'){
now=S[--top];
continue;
}
if(s[i]=='P'){
++cnt;
ed[cnt]=now;
continue;
}
if(next[now][s[i]-'a']==-1)
next[now][s[i]-'a']=newnode();
now=next[now][s[i]-'a'];
S[++top]=now;
}
}
void build(){
int head=0,tail=0;
fail[root]=root;
for(int i=0;i<26;i++)
if(next[root][i]!=-1)
fail[next[root][i]]=root,Q[++tail]=next[root][i],addedge(root,next[root][i]);
while(head<tail){
int now=Q[++head];
for(int i=0;i<26;i++)
if(next[now][i]!=-1){
int k=fail[now];
while(k!=root&&next[k][i]==-1) k=fail[k];
if(next[k][i]!=-1)
k=next[k][i];
fail[next[now][i]]=k;
Q[++tail]=next[now][i];
addedge(k,next[now][i]);
}
}
}
void dfs(int u){
L[u]=++id;
for(int i=head[u];i!=-1;i=e[i].next){
int v=e[i].to;
dfs(v);
}
R[u]=id;
}
int sum(int x){
int ret=0;
while(x>0) ret+=C[x],x-=(x&-x);
return ret;
}
void add(int x,int v){
while(x<=sz){
C[x]+=v;
x+=(x&-x);
}
}
void solve(){
int now=root,cnt=0,len=strlen(s),top=0,ID=1;
for(int i=0;i<len;i++){
if(s[i]=='B'){
add(L[S[top--]],-1);
now=S[top];
continue;
}
if(s[i]=='P'){
++cnt;
while(ID<=q&&a[ID].y==cnt){ //a[ID].x表示第几个删除
ans[a[ID].id]=sum(R[ ed[a[ID].x] ])-sum(L[ ed[a[ID].x] ]-1);
ID++;
}
continue;
}
now=next[now][s[i]-'a'];
add(L[now],1);
S[++top]=now;
}
}
};
Trie ac;
bool cmp(const node& u,const node& v){
if(u.y!=v.y) return u.y<v.y;
return u.x<v.x;
}
int main(){
ac.init();
init();
scanf("%s",s);
ac.insert();
ac.build();
ac.dfs(0);
scanf("%d",&q);
for(int i=1;i<=q;i++) scanf("%d%d",&a[i].x,&a[i].y),a[i].id=i;
sort(a+1,a+q+1,cmp);
ac.solve();
for(int i=1;i<=q;i++) printf("%d\n",ans[i]);
return 0;
}