题意:
给定一个 \(n\) 个节点的 trie 树,求 \(m\) 对节点作为结束位置代表的字符串之间的包含关系。
题解:
对 trie 树建 AC 自动机。根据 fail 指针的含义,我们发现,字符串 \(a\) 在字符串 \(b\) 中出现的次数等于 \(b\) 中 fail 指针直接或间接指向 \(a\) 的结束节点的节点数。
所以,我们用 fail 指针的反向边建出 fail 树,则这个数字又可以转化为 \(a\) 的结束节点的 fail 树上属于 \(b\) 的节点数量。
如果根据 fail 树的 dfs 序构造出 dfs 序数组,则 \(a\) 的结束节点的子树一定是这个数组中连续的一段,对于每一个询问 \(a,b\),问题又可以转变为区间求和。
因此,我们想到了一种离线的做法:我们以 trie 上的 dfs 序遍历每个 \(b\),并维护 \(b\) 中的节点在这个数组上出现的情况,对于每一对 \(a,b\) 的询问,就是对 \(a\) 的结束节点在 dfs 序数组上的开始位置和结束位置求区间和,可以用树状数组维护。
下面是离线做法的代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<map>
using namespace std;
const int maxn=1e5+5;
struct BiTree{
int val[maxn];
BiTree(){
memset(val,0,sizeof(val));
}
void initiate(){
memset(val,0,sizeof(val));
}
inline int lowbit(int n){
return n&(-n);
}
void add(int n,int v){
int i;
for (i=n;i<maxn;i+=lowbit(i))
val[i]+=v;
}
int query(int n){
int i,ans;
ans=0;
for (i=n;i>0;i-=lowbit(i))
ans+=val[i];
return ans;
}
} bi_tree;
struct AcAuto{
const static int sigma=26;
static int que[maxn];
int ch[maxn][sigma],fa[maxn],fail[maxn];
int pos[maxn];
int ndcnt,strcnt,root;
int in[maxn],out[maxn],ope[maxn];
vector<int> next[maxn];
map<int,int> query[maxn];
AcAuto(){
ndcnt=0;
strcnt=0;
}
int new_node(){
++ndcnt;
memset(ch[ndcnt],0,sizeof(ch[ndcnt]));
fa[ndcnt]=0; fail[ndcnt]=0;
return ndcnt;
}
void initiate(){
ndcnt=0;
strcnt=0;
root=new_node();
}
inline int index(char ch){
return ch-'a';
}
void add_str(char s[]){
int u=root;
for (int i=0;s[i];++i){
if (s[i]=='B'){
u=fa[u];
continue;
}
if (s[i]=='P'){
pos[++strcnt]=u;
continue;
}
int t=index(s[i]);
if (!ch[u][t]){
int v=new_node();
ch[u][t]=v;
fa[v]=u;
}
u=ch[u][t];
}
}
void build_fail(){
int l=0,r=0,u=root,v,w;
for (int i=0;i<sigma;++i) if (ch[u][i]){
v=ch[u][i];
que[r++]=v;
fail[v]=root;
next[root].push_back(v);
}
while (l<r){
u=que[l++];
for (int i=0;i<sigma;++i) if (ch[u][i]){
v=ch[u][i];
que[r++]=v;
w=fail[u];
while (w&&!ch[w][i]) w=fail[w];
fail[v]=w?ch[w][i]:root;
next[fail[v]].push_back(v);
}
}
}
void fail_dfs(int u,int &time){
in[u]=++time;
for (int i=0;i<next[u].size();++i){
int v=next[u][i];
fail_dfs(v,time);
}
out[u]=time;
}
void trie_dfs(int u){
bi_tree.add(in[u],1);
map<int,int>::iterator it;
for (it=query[u].begin();it!=query[u].end();it++)
it->second=bi_tree.query(out[it->first])
-bi_tree.query(in[it->first]-1);
for (int i=0;i<sigma;i++) if (ch[u][i]){
trie_dfs(ch[u][i]);
}
bi_tree.add(in[u],-1);
}
void build_dfs(){
int time=0;
fail_dfs(root,time);
trie_dfs(root);
}
} ac_auto;
int AcAuto::que[maxn];
char s[maxn];
int x[maxn],y[maxn];
int main(){
int n;
ac_auto.initiate();
scanf("%s",s);
ac_auto.add_str(s);
ac_auto.build_fail();
scanf("%d",&n);
for (int i=0;i<n;i++){
scanf("%d%d",&x[i],&y[i]);
ac_auto.query[ac_auto.pos[y[i]]][ac_auto.pos[x[i]]]=0;
}
ac_auto.build_dfs();
for (int i=0;i<n;i++){
printf("%d\n",ac_auto.query[ac_auto.pos[y[i]]][ac_auto.pos[x[i]]]);
}
return 0;
}
另外,还有一种对每一组询问在线的做法:我们可以用可持久化线段树来维护保存每一个字符串中的字符在 dfs 序数组中的出现情况,再对每一对 \(a,b\) 求区间和即可。但是,这种做法在原题中是不能通过的。(因为卡了内存,甚至连指针实现的 AC 自动机都会卡到)
下面是这种写法的代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
using namespace std;
const int maxn=1e5+5;
struct SegmentTree{
const static int maxd=maxn*20*2;
int ls[maxd],rs[maxd];
int l[maxd],r[maxd],sum[maxd];
int ndcnt,opecnt,root[maxn];
SegmentTree(){
ndcnt=0;
opecnt=0;
}
int new_node(){
++ndcnt;
l[ndcnt]=r[ndcnt]=0;
ls[ndcnt]=rs[ndcnt]=0;
sum[ndcnt]=0;
return ndcnt;
}
void build_tree(int u,int ll,int rr){
l[u]=ll; r[u]=rr;
if (ll==rr) return;
int mm=(ll+rr)>>1;
ls[u]=new_node();
build_tree(ls[u],ll,mm);
rs[u]=new_node();
build_tree(rs[u],mm+1,rr);
}
void initiate(){
ndcnt=0;
opecnt=0;
root[0]=new_node();
build_tree(root[0],0,maxn);
}
void build_chain(int u,int v,int pos,int val){
ls[v]=ls[u]; rs[v]=rs[u];
l[v]=l[u]; r[v]=r[u]; sum[v]=sum[u]+val;
if (l[v]==r[v]) return;
int m=l[v]+r[v]>>1;
if (pos<=m){
ls[v]=new_node();
build_chain(ls[u],ls[v],pos,val);
}
else{
rs[v]=new_node();
build_chain(rs[u],rs[v],pos,val);
}
}
void add(int pos,int val){
root[++opecnt]=new_node();
build_chain(root[opecnt-1],root[opecnt],pos,val);
}
int query(int u,int ll,int rr){
if (l[u]>=ll&&r[u]<=rr) return sum[u];
if (l[u]>rr||r[u]<ll) return 0;
return query(ls[u],ll,rr)+query(rs[u],ll,rr);
}
} segment_tree;
struct AcAuto{
const static int sigma=26;
static int que[maxn];
int ch[maxn][sigma],fa[maxn],fail[maxn];
int pos[maxn];
int ndcnt,strcnt,root;
int in[maxn],out[maxn],ope[maxn];
vector<int> next[maxn];
AcAuto(){
ndcnt=0;
strcnt=0;
}
int new_node(){
++ndcnt;
memset(ch[ndcnt],0,sizeof(ch[ndcnt]));
fa[ndcnt]=0; fail[ndcnt]=0;
return ndcnt;
}
void initiate(){
ndcnt=0;
strcnt=0;
root=new_node();
}
inline int index(char ch){
return ch-'a';
}
void add_str(char s[]){
int u=root;
for (int i=0;s[i];++i){
if (s[i]=='B'){
u=fa[u];
continue;
}
if (s[i]=='P'){
pos[++strcnt]=u;
continue;
}
int t=index(s[i]);
if (!ch[u][t]){
int v=new_node();
ch[u][t]=v;
fa[v]=u;
}
u=ch[u][t];
}
}
void build_fail(){
int l=0,r=0,u=root,v,w;
for (int i=0;i<sigma;++i) if (ch[u][i]){
v=ch[u][i];
que[r++]=v;
fail[v]=root;
next[root].push_back(v);
}
while (l<r){
u=que[l++];
for (int i=0;i<sigma;++i) if (ch[u][i]){
v=ch[u][i];
que[r++]=v;
w=fail[u];
while (w&&!ch[w][i]) w=fail[w];
fail[v]=w?ch[w][i]:root;
next[fail[v]].push_back(v);
}
}
}
void trie_dfs(int u){
segment_tree.add(in[u],1);
ope[u]=segment_tree.opecnt;
for (int i=0;i<sigma;i++) if (ch[u][i]){
trie_dfs(ch[u][i]);
}
segment_tree.add(in[u],-1);
}
void fail_dfs(int u,int &time){
in[u]=++time;
for (int i=0;i<next[u].size();++i){
int v=next[u][i];
fail_dfs(v,time);
}
out[u]=time;
}
void build_dfs(){
int time=0;
fail_dfs(root,time);
trie_dfs(root);
}
} ac_auto;
int AcAuto::que[maxn];
char s[maxn];
int main(){
int n;
int x,y;
ac_auto.initiate();
scanf("%s",s);
ac_auto.add_str(s);
ac_auto.build_fail();
segment_tree.initiate();
ac_auto.build_dfs();
scanf("%d",&n);
for (int i=0;i<n;i++){
scanf("%d%d",&x,&y);
printf("%d\n",segment_tree.query(
segment_tree.root[ac_auto.ope[ac_auto.pos[y]]],
ac_auto.in[ac_auto.pos[x]],
ac_auto.out[ac_auto.pos[x]]
));
}
return 0;
}