Luogu P2414 [NOI2011] 阿狸的打字机
私以为Luogu题解的第一篇写的很好。。这里算是用我自己的理解重述一下吧。
First-Step:暴力
因为要匹配出现次数,而且是多个串的问题,会想到AC自动机。这里要求的是trie树上一个串在另一个串中出现的次数,也算是一个套路吧。记录一下trie树上每个位置代表了第几个串的结束,也记录每个串的结束位置。这样每次询问x在y中出现了多少次,就可以直接从y的end_pos开始,每个点都遍历它的fail指针,如果转移到一次x的end_pos 结果就+1.这样就实现了在一棵trie树上完成所有询问了。
这样复杂度有点高啊。
Second-Step:优化
对于相同的y,我们实际上是跑了好多次y这个串以及它的fail指针,然后每次处理的是不同的x罢了。
那这样,空间既然足够,我们就开个桶,对于相同的y只跑一遍,每遇到一个点就让它对应的tong ++,这样对于每个y得到了一个桶,然后这一段区间的答案就都知道了。
这样做要离线,对查询按y排序。然后统一处理。
Third-Step:转化思路
当前思路已经很难再继续优化了,我们要想办法转换一下求答案的方法。
之前我们求答案是从y往上找x,是倒着找,那么我们能不能先把y标记上,然后对于每个点x,求一下能转移到它的节点有几个点是y能转移到的。
倒着建fail树,然后dfs序处理一下。对于几个连续的y,我们先把y能转移到的点对应的dfs序的left标记上1,然后对于每个y的查询x,我们只需要查询x的子树中1的和,也就是查询区间和(dfs序上).维护用树状数组,常数小一点。
Foreth-Step:终极优化
当前思路已经足够优秀了,但是还是会TLE,那我们想一下哪里还有可能重复计算了呢?
答案当然是插入Y这个操作。很多的y其实本质上是没有什么差别的,比如aaaab和aaaac,我们在aaa这条路径上多跑了一次。我们要想办法优化这一步。
之前我们记录了每个点的起始位置和每个串对应在trie树上的节点。
现在,我们直接对整个fail树跑dfs。遇到一个新的点,我们就把它dfs序的左端点+1,遇到一个end点,我们看一下它是否对应询问中的一个区间,如果对应着,我们再像刚才一样,遍历它对应的x,然后用树状数组查询区间和。这样就再没有重复计算了。
Tips:
这样做的前提好像是不能有重复的字符串。如果出现重复的,比如说y1,y2,trie中我们统计的节点对应字符串会出现两个,这时y2覆盖了y1,当出现查询x,y1时,由于在trie树中没有y1,所以我们找不到x,y1对应的答案,这时就错了。
解决方法是对于每个节点维护一个multi_set或者vector,然后对于一个节点的所有终止编号都遍历一遍,因为输入只有10W个,所以这样复杂度还是一样的,用vector是省空间。速度也不会慢多少。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<queue>
#include<stack>
#include<map>
#include<ctime>
#define up(i,x,y) for(int i = x;i <= y;i ++)
#define down(i,x,y) for(int i = x;i >= y;i --)
#define mem(a,b) memset((a),(b),sizeof(a))
#define mod(x) ((x)%MOD)
#define lson p<<1
#define rson p<<1|1
using namespace std;
typedef long long ll;
const int SIZE = 200010;
const int INF = 2147483640;
const double eps = 1e-8;
inline void RD(int &x)
{
x = 0; char c; c = getchar();
bool flag = 0;
if(c == '-') flag = 1;
while(c < '0' || c > '9') {if(c == '-') {flag = 1;} c = getchar();}
while(c >= '0' && c <= '9') x = (x << 1) + (x << 3) + c - '0',c = getchar();
}
int trie[SIZE][26],root,tot,sum[SIZE],fa[SIZE],id[SIZE];
queue <int> q;
int fail[SIZE];
//仿照KMP算法的Next数组,我们也对树上的每一个节点建立一个前缀指针。这个前缀指针的定义和KMP算法中的next数组相类似,
//从根节点沿边到节点p我们可以得到一个字符串SX,节点p的前缀指针定义为:指向树中出现过的S的最长的后缀(不能等于S)。
void build_ac()
{
fail[0] = 0;//根节点的失败指针指向本身
for(int i = 0;i < 26;i ++)
{
int u = trie[0][i];
if(u)
{
q.push(u);
fail[u] = 0;//根节点指向的节点的fail指针只可能指向根节点
}
}
while(!q.empty())
{
int f = q.front();//取队首
q.pop();
for(int i = 0;i < 26;i ++)
{
int u = trie[f][i];
if(!u) continue;//没有i儿子 不管
q.push(u);
int v = fail[f];//父亲的fail指针
while(v && !trie[v][i]) v = fail[v];
//如果v不是0(根节点)并且从v到i没有边(没有匹配成功)
//失败指针前移
fail[u] = trie[v][i];
//注意 此时failu变成了 与父亲相同字母的fail指针的 它的儿子字母为i的 点的编号
}
}
}
struct Edge{
int to;
}edges[SIZE*2];
int head[SIZE],nextt[SIZE*2],tot_graph;
void build(int f,int t)
{
edges[++tot_graph].to = t;
nextt[tot_graph] = head[f];
head[f] = tot_graph;
}
void build_graph()
{
for(int i = 1;i <= tot;i ++)
{
int f = fail[i];
build(f,i);//父亲指向儿子 倒着建fail树
}
}
int dfs_clock,l[SIZE],r[SIZE];
void dfs_get_clock(int u)
{
l[u] = ++dfs_clock;
for(int i = head[u];i;i = nextt[i])
{
dfs_get_clock(edges[i].to);
}
r[u] = dfs_clock;
}
namespace BIT{
const int N = 200000;
int bit[N];
int lowbit(int x){return x&-x;};
void add(int x,int v){while(x<N)bit[x]+=v,x+=lowbit(x);}
int sum(int x){int r=0;while(x)r+=bit[x],x-=lowbit(x);return r;}
}
char s[SIZE];
int n;
int endd[SIZE];//每个字符串对应的终止节点
struct Ques{
int x,y,id,ans;
}quest[SIZE];
int qu_l[SIZE],qu_r[SIZE];//相同的y的左右端点
bool cmp(Ques &a,Ques &b){return a.y < b.y;}
bool cmp2(Ques &a,Ques &b){return a.id < b.id;}
void solve(int x)//遍历trie树
{
BIT::add(l[x],1);//把这个点置为1
if(id[x])//这个点是一个结尾点
{
for(int i = qu_l[id[x]];i <= qu_r[id[x]];i ++)//处理这个区间内的所有询问
{
quest[i].ans = BIT::sum(r[endd[quest[i].x]]) - BIT::sum(l[endd[quest[i].x]] - 1);
}
}
for(int i = 0;i < 26;i ++)
{
if(trie[x][i]) solve(trie[x][i]);
}
BIT::add(l[x],-1);
}
int main(int argc, char const *argv[])
{
scanf("%s",s);
int l = strlen(s);
root = 0;
for(int i = 0;i < l;i ++)
{
if(s[i] >= 'a' && s[i] <= 'z')
{
int x = s[i]-'a';
if(!trie[root][x]) trie[root][x] = ++tot,fa[trie[root][x]] = root;
root = trie[root][x];
}
if(s[i] == 'B') root = fa[root];
if(s[i] == 'P') endd[++n] = root,id[root] = n;
}
build_ac();
build_graph();
dfs_get_clock(0);
int q;
scanf("%d",&q);
for(int i = 1;i <= q;i ++)
{
scanf("%d%d",&quest[i].x,&quest[i].y);
quest[i].id = i;
}
sort(quest+1,quest+1+q,cmp);
int pos = 1;//跳转
for(int i = 1;i <= q;i = pos)
{
qu_l[quest[i].y] = i;
while(quest[pos].y == quest[i].y) pos ++;
qu_r[quest[i].y] = pos-1;
}
solve(0);
sort(quest+1,quest+1+q,cmp2);
for(int i = 1;i <= q;i ++)
{
printf("%d\n",quest[i].ans);
}
return 0;
}