题目描述
题目大意:给一个字符串集合S、一堆操作。操作1就是在在另一个集合T中加入一个串P,操作2就是问你S中第x个串是T中多少个串的子串。
|S|,q<=1e5,S中串总长度L1和T中串的总长度L2分别<=2e6。
题解
考虑把S集合里的串那去建AC自动机,然后每次读入一个P,它的贡献就是从它在树上能匹配到的节点,沿着fail指针一直往上跳能走到的所有点。
于是直接对S集合的串建出fail树,然后每个匹配的贡献就是沿着fail树上的点到根的一条链。这里的fail树是AC机上所有节点构成的树,我们不关心上面是否有尾标记。
于是我们要求的是树链的并。我们可以直接在线用树链剖分来解决,但明显会被卡。所以我们可以搞出fail树的dfs序,然后用bit维护单点修改和区间查询。
这样树上每个点对应一段区间[L,R]。当前点的祖先的R大于它,L小于它,对于一个点在其L处打个标记。最后算答案时用R的前缀和减(L-1)的前缀得到区间内标记的数量。
但是这样标记会重复,一部分节点会被标记多次。我们将所有匹配到的节点丢进数组里按dfs序排序,然后在相邻的点的lca处将标记减去即可。这样一个点与相邻的点的公共部分是最长的,标记去重时不会出错,否则会出错。
于是总的复杂度不会超过NlogN,求LCA用RMQ会MLE,改为倍增。还有在bzoj测这题时会莫名CE,CE了好多次,最后直接弃坑去吃鸡终于还是AC了。话说这题用trie图写会短好多,但我没这样做。
AC自动机、fail树、BIT(dfs序)三个东西都存了S串的节点,注意节点分别所对应的编号,不要弄混了就好。
代码
#include <bits/stdc++.h>
#define maxL 2000100
#define maxN 100010
#define Lg 22
using namespace std;
char S[maxL], T[maxL];
int n, q, cur = -1, cnt, dfn;
int P[maxN], f[Lg][maxL], dep[maxL], Mark[maxL];
int st[maxL], ed[maxL], BIT[maxL];
struct AC{
AC *son[26], *fail;
int id;
}Node[maxL], *Root, *que[maxL];
AC *NewTnode(){
for(int i = 0; i < 26; i++) Node[cnt].son[i] = NULL;
Node[cnt].fail = NULL;
Node[cnt].id = cnt;
return Node+cnt++;
}
struct List{
List *next;
int obj;
}*head[maxL], Edg[maxL];
void Addedge(int a, int b){
Edg[++cur].next = head[a];
Edg[cur].obj = b;
head[a] = Edg+cur;
}
void Insert(char *x, int id){
AC *now = Root;
int len = strlen(x);
for(int i = 0; i < len; i++){
int pos = x[i] - 'a';
if(!now->son[pos]) now->son[pos] = NewTnode();
now = now->son[pos];
}
P[id] = now->id;
}
void Build(){
AC *now, *temp;
int hh = 0, tt = 0;
que[hh] = Root;
Root->fail = NULL;
for(int i = 0; i < cnt; i++) head[i] = NULL;
while(hh <= tt){
now = que[hh++];
for(int i = 0; i < 26; i++){
if(!now->son[i]) continue;
que[++tt] = now->son[i];
now->son[i]->fail = Root;
temp = now->fail;
while(temp && !temp->son[i]) temp = temp->fail;
if(temp) now->son[i]->fail = temp->son[i];
Addedge(now->son[i]->fail->id, now->son[i]->id);
}
}
}
int LCA(int x, int y){
if(dep[x] > dep[y]) swap(x, y);
for(int i = Lg-1; i >= 0; i--)
if(dep[f[i][y]] >= dep[x]) y = f[i][y];
if(x == y) return x;
for(int i = Lg-1; i >= 0; i--)
if(f[i][x] != f[i][y]){
x = f[i][x];
y = f[i][y];
}
return f[0][x];
}
int lowbit(int x){
return x & (-x);
}
void Add(int x, int v){
for(int i = x; i <= dfn; i += lowbit(i)) BIT[i] += v;
}
int Sum(int x){
int res = 0;
for(int i = x; i; i -= lowbit(i)) res += BIT[i];
return res;
}
bool cmp(int x, int y){
return st[x] < st[y];
}
void Find(char *x){
AC *now = Root;
Mark[0] = 0;
int len = strlen(x);
for(int i = 0; i < len; i++){
int pos = x[i] - 'a';
while(now != Root && !now->son[pos]) now = now->fail;
if(!now->son[pos]) continue;
now = now->son[pos];
Mark[++Mark[0]] = now->id;
}
sort(Mark+1, Mark+Mark[0]+1, cmp);
Add(st[Mark[1]], 1);
for(int i = 2; i <= Mark[0]; i++){
int lca = LCA(Mark[i-1], Mark[i]);
Add(st[lca], -1);
Add(st[Mark[i]], 1);
}
}
void Dfs(int x){
st[x] = ++dfn;
for(List *p = head[x]; p; p = p->next){
int v = p->obj;
dep[v] = dep[x] + 1;
f[0][v] = x;
Dfs(v);
}
ed[x] = dfn;
}
int main(){
scanf("%d", &n);
Root = NewTnode();
for(int i = 1; i <= n; i++){
scanf("%s", S);
Insert(S, i);
}
Build();
Dfs(0);
for(int i = 1; i < Lg; i++)
for(int j = 1; j <= dfn; j++)
f[i][j] = f[i-1][f[i-1][j]];
int op, x;
scanf("%d", &q);
for(int i = 1; i <= q; i++){
scanf("%d", &op);
if(op == 1){
scanf("%s", T);
Find(T);
}
else{
scanf("%d", &x);
printf("%d\n", Sum(ed[P[x]]) - Sum(st[P[x]]-1));
}
}
return 0;
}