Description
你有一个字符串S,一开始为空串,要求支持两种操作
在S后面加入字母C
删除S最后一个字母
问每次操作后S有多少个两两不同的连续子串
Solution
似乎暴力也能过的样子
一个显然的做法就是建后缀平衡树,但是好像有点难写啊
考虑离线,给出的串刚好就是一棵Trie,我们按照Trie建广义SAM之后模拟就可以了
具体说就是每次都在处理Trie上的一条链,答案就是这些点在parent树上到根形成链的并的长度之和
因为每次只会插入/删除一个点,那么不妨用set维护虚树,这样就是一个log的了
Code
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <set>
#define rep(i,st,ed) for (int i=st;i<=ed;++i)
#define copy(x,t) memcpy(x,t,sizeof(x))
typedef long long LL;
const int N=400005;
char str[N];
struct SAM {
struct edge {int y,next;} e[N];
int size[N],dep[N],bl[N],pos[N],dfn[N];
int rec[N][26],len[N],tot;
int ls[N],fa[N],edCnt;
SAM() {tot=1;}
void add_edge(int x,int y) {
e[++edCnt]=(edge) {y,ls[x]}; ls[x]=edCnt;
}
int extend(int p,int ch) {
int q,np,nq;
if (rec[p][ch]) {
q=rec[p][ch];
if (len[p]+1==len[q]) return q;
else {
nq=++tot; len[nq]=len[p]+1;
copy(rec[nq],rec[q]);
fa[nq]=fa[q];
fa[q]=nq;
for (;p&&rec[p][ch]==q;p=fa[p]) rec[p][ch]=nq;
return nq;
}
}
np=++tot; len[np]=len[p]+1;
for (;p&&!rec[p][ch];p=fa[p]) rec[p][ch]=np;
if (!p) fa[np]=1;
else {
q=rec[p][ch];
if (len[p]+1==len[q]) fa[np]=q;
else {
nq=++tot; len[nq]=len[p]+1;
copy(rec[nq],rec[q]);
fa[nq]=fa[q];
fa[np]=fa[q]=nq;
for (;p&&rec[p][ch]==q;p=fa[p]) rec[p][ch]=nq;
}
}
return np;
}
void dfs1(int x) {
size[x]=1;
for (int i=ls[x];i;i=e[i].next) {
if (e[i].y==fa[x]) continue;
fa[e[i].y]=x; dep[e[i].y]=dep[x]+1;
dfs1(e[i].y); size[x]+=size[e[i].y];
}
}
void dfs2(int x,int up) {
pos[x]=++pos[0]; dfn[pos[0]]=x;
bl[x]=up; int mx=0;
for (int i=ls[x];i;i=e[i].next) {
if (e[i].y!=fa[x]&&size[e[i].y]>size[mx]) mx=e[i].y;
}
if (!mx) return ;
dfs2(mx,up);
for (int i=ls[x];i;i=e[i].next) {
if (e[i].y!=fa[x]&&e[i].y!=mx) dfs2(e[i].y,e[i].y);
}
}
int get_lca(int x,int y) {
for (;bl[x]^bl[y];x=fa[bl[x]]) if (dep[bl[x]]<dep[bl[y]]) std:: swap(x,y);
return dep[x]<dep[y]?x:y;
}
} SAM;
int rec[N][26],fa[N],pos[N],tot=1;
void build(int x) {
rep(i,0,25) if (rec[x][i]) {
pos[rec[x][i]]=SAM.extend(pos[x],i);
build(rec[x][i]);
}
}
int main(void) {
scanf("%s",str+1);
int n=strlen(str+1),x=1;
rep(i,1,n) {
if (str[i]!='-') {
int ch=str[i]-'a';
if (!rec[x][ch]) rec[x][ch]=++tot,fa[tot]=x;
x=rec[x][ch];
} else x=fa[x];
}
build(pos[1]=1); LL ans=0; x=1;
rep(i,2,SAM.tot) SAM.add_edge(SAM.fa[i],i);
SAM.dfs1(SAM.dep[1]=1); SAM.dfs2(1,1);
std:: set <int> set;
rep(i,1,n) {
if (str[i]!='-') {
x=rec[x][str[i]-'a'];
int tx=0,ty=0;
set.insert(SAM.pos[pos[x]]);
std:: set <int>:: iterator it=set.find(SAM.pos[pos[x]]);
ans+=SAM.len[pos[x]];
it++; if (it!=set.end()) tx=SAM.dfn[*it]; it--;
if (it!=set.begin()) it--,ty=SAM.dfn[*it],it++;
if (tx) ans-=SAM.len[SAM.get_lca(tx,pos[x])];
if (ty) ans-=SAM.len[SAM.get_lca(ty,pos[x])];
if (tx&&ty) ans+=SAM.len[SAM.get_lca(tx,ty)];
} else {
int tx=0,ty=0;
std:: set <int>:: iterator it=set.find(SAM.pos[pos[x]]);
ans-=SAM.len[pos[x]];
it++; if (it!=set.end()) tx=SAM.dfn[*it]; it--;
if (it!=set.begin()) it--,ty=SAM.dfn[*it],it++;
if (tx) ans+=SAM.len[SAM.get_lca(tx,pos[x])];
if (ty) ans+=SAM.len[SAM.get_lca(ty,pos[x])];
if (tx&&ty) ans-=SAM.len[SAM.get_lca(tx,ty)];
x=fa[x]; set.erase(it);
}
printf("%lld\n", ans);
}
return 0;
}