题目大意
有一个非严格的Trie,有三种操作:新加一个子树、求Trie上本质不同的字符串个数、询问一个字符串在trie上的出现次数。
分析
很容易想到用SAM来做。
在线用LCT维护即可。
然而这题允许离线,这里我用了个离线的做法。
首先把最终的trie构出来,建个SAM,然后每当新加入节点就把相应的信息加进去。
现在看看如何处理两个询问:
1. 询问本质不同的字符串个数。这个显然等于当前SAM上所有节点表示的最大长度减fail树它的父亲表示的最大长度之和(step[x]-step[fail[x]])。当加入一个节点时,从它所对应SAM上的节点沿fail树往上跳,统计一下答案,并标记已访问。直到跳到访问过的节点。这一问的总时间复杂度是O(N)的。
2. 询问一个字符串的出现次数。这个也很简单,就是对应节点在fail树上对应的子树表示多少个原trie上的节点。可以用树状数组实现区间求和、单点加操作。
总复杂度就是O(NlogN)的了~
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=200005;
typedef long long LL;
int n,h[N],e[N],nxt[N],Dfn[N],tot,m,Size[N],Sum[N],Now,D[N],fa[N],cnt;
int Son[N][3],Id[N],Step[N],Fail[N];
bool Visit[N];
LL ans;
char c,C[N],Q[N*10];
struct Op
{
int typ,x;
}O[N];
int read()
{
int x=0,sig=1;
for (c=getchar();c<'0' || c>'9';c=getchar()) if (c=='-') sig=-1;
for (;c>='0' && c<='9';c=getchar()) x=x*10+c-48;
return x*sig;
}
void add(int x,int y,char c)
{
e[++tot]=y; nxt[tot]=h[x]; C[tot]=c; h[x]=tot;
}
void Init()
{
n=read(); n=read();
for (int i=1;i<n;i++)
{
int x=read(),y=read();
for (;c<'a' || c>'c';c=getchar());
add(x,y,c); add(y,x,c);
}
m=read();
for (int i=0;i<m;i++)
{
O[i].typ=read();
if (O[i].typ==2)
{
int r=read(); O[i].x=read();
for (int j=1;j<O[i].x;j++)
{
int x=read(),y=read();
for (;c<'a' || c>'c';c=getchar());
add(x,y,c); add(y,x,c);
}
}else if (O[i].typ==3)
{
for (;c<'a' || c>'z';c=getchar());
for (;c!='\n' && c!='\r';c=getchar())
{
O[i].x++;
Q[Now++]=c;
}
}
}
}
int Extend(int Last,char c)
{
c-='a';
if (Son[Last][c]>0 && Step[Son[Last][c]]==Step[Last]+1) return Son[Last][c];
int p=Last,np=++cnt,q,nq;
Step[np]=Step[p]+1;
for (;p>=0 && Son[p][c]==0;p=Fail[p]) Son[p][c]=np;
if (p<0) Fail[np]=0;else
{
q=Son[p][c];
if (Step[q]==Step[p]+1) Fail[np]=q;else
{
nq=++cnt;
Step[nq]=Step[p]+1;
memcpy(Son[nq],Son[q],sizeof(Son[q]));
Fail[nq]=Fail[q];
Fail[np]=Fail[q]=nq;
for (;p>=0 && Son[p][c]==q;p=Fail[p]) Son[p][c]=nq;
}
}
return np;
}
void Add(int x,int y)
{
e[++tot]=y; nxt[tot]=h[x]; h[x]=tot;
}
void Dfs(int x)
{
Dfn[x]=++tot;
for (int i=h[x];i;i=nxt[i])
{
Dfs(e[i]); Size[x]+=Size[e[i]]+1;
}
}
int Lowbit(int x)
{
return x&-x;
}
void Change(int x)
{
for (;x<=cnt+1;x+=Lowbit(x)) Sum[x]++;
}
int Get_sum(int x)
{
int s=0;
for (;x;x-=Lowbit(x)) s+=Sum[x];
return s;
}
void Ins(int x)
{
Change(Dfn[x]);
for (;x>0 && !Visit[x];x=Fail[x])
{
Visit[x]=1; ans+=Step[x]-Step[Fail[x]];
}
}
void Work()
{
tot=0; Fail[0]=-1;
D[1]=1;
for (int i=1,j=1,x,k;i<=j;i++)
{
x=D[i];
for (k=h[x];k;k=nxt[k]) if (e[k]!=fa[x])
{
fa[e[k]]=x;
Id[e[k]]=Extend(Id[x],C[k]);
D[++j]=e[k];
}
}
memset(h,0,sizeof(h));
for (int i=1;i<=cnt;i++) Add(Fail[i],i);
tot=Now=0;
Dfs(0);
for (int i=2;i<=n;i++) Ins(Id[i]);
for (int i=0;i<m;i++)
{
if (O[i].typ==1) printf("%lld\n",ans);
else if (O[i].typ==2)
{
for (int j=1;j<O[i].x;j++) Ins(Id[++n]);
}else
{
bool bz=1;
int j,x;
for (x=j=0;j<O[i].x;j++,Now++)
if (!Son[x][Q[Now]-'a']) bz=0;else x=Son[x][Q[Now]-'a'];
if (bz) printf("%d\n",Get_sum(Dfn[x]+Size[x])-Get_sum(Dfn[x]-1));
else printf("0\n");
}
}
}
int main()
{
Init();
Work();
return 0;
}