题目链接:https://www.luogu.org/problemnew/show/P3181
广义后缀自动机,基本类似于将多个串用分隔符连一起建后缀自动机,不过可能在空间上有优化(不用分隔符?串超级多分隔符不够的时候就必须用它了...),建法网上看到有两种,其实复杂度什么都是一样的,不过一种靠多判一种情况可能在实际使用空间上有一些常数的优化,并且更严谨一些。本质都是建完一个串将las点移回到根,然后比普通后缀自动机多判断当前插入np结点信息在之前串中出现过完全一样或者出现过部分一样的。
那种多判情况的的基本上就是在存在一个之前节点np‘,它包含的信息比我当前要的实际np多的情况时,用那个拷贝分裂节点的方法搞一个新的。而没有判的则是之间按普通后缀自动机的构建插入,会多一个np和nt,稍微画一下发现与之前判掉的实际是一样的,不过多用了1个点(这个点其实没有啥意义)...虽然少判一种情况感觉复杂度上界没什么区别,但有时候会出一些小锅(比如桶排时一定要从小到大赋rnk,因为有些mx相同的点会存在父子关系),所以建议还是写多判情况的。
代码:
判的写法:
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=8e5+100;
const int M=2e5+100;
struct Sam{
Sam(){las=tot=1;}
int nxt[N][26],fa[N],sz[N][2],mx[N],las,tot;
void ins(int x,int op)
{
int p=las,np=nxt[las][x],t,nt;
if(np)
{
if(mx[p]+1==mx[np])las=np,sz[np][op]=1;
else
{
nt=++tot;
mx[nt]=mx[p]+1;
memcpy(nxt[nt],nxt[np],sizeof nxt[np]);
fa[nt]=fa[np],fa[np]=nt;
while(p&&nxt[p][x]==np)nxt[p][x]=nt,p=fa[p];
sz[nt][op]=1,las=nt;
}
return;
}
np=++tot,mx[np]=mx[p]+1,las=np,sz[np][op]++;
while(p&&!nxt[p][x])nxt[p][x]=np,p=fa[p];
if(!p){fa[np]=1;return;}
t=nxt[p][x];
if(mx[t]==mx[p]+1)fa[np]=t;
else
{
nt=++tot;mx[nt]=mx[p]+1;
memcpy(nxt[nt],nxt[t],sizeof nxt[t]);
fa[nt]=fa[t],fa[t]=fa[np]=nt;
while(p&&nxt[p][x]==t)nxt[p][x]=nt,p=fa[p];
}
}
int tax[M],rnk[N];
void sol()
{
ll ans=0;
for(int i=1;i<=tot;i++)tax[mx[i]]++;
for(int i=1;i<M;i++)tax[i]+=tax[i-1];
for(int i=1;i<=tot;i++)rnk[tax[mx[i]]--]=i;
for(int ii=tot,i;ii>=1;ii--)
{
i=rnk[ii];
sz[fa[i]][0]+=sz[i][0];
sz[fa[i]][1]+=sz[i][1];
ans+=1LL*(mx[i]-mx[fa[i]])*sz[i][0]*sz[i][1];
}
printf("%lld\n",ans);
}
}sam;
int n;
char S[N];
int main()
{
scanf("%s",S+1),n=strlen(S+1);
for(int i=1;i<=n;i++)sam.ins(S[i]-'a',0);
scanf("%s",S+1),n=strlen(S+1);
sam.las=1;
for(int i=1;i<=n;i++)sam.ins(S[i]-'a',1);//cerr<<sam.tot<<'\n';
sam.sol();
}
不判写法:
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=8e5+100;
const int M=2e5+100;
struct Sam{
Sam(){las=tot=1;}
int nxt[N][26],fa[N],sz[N][2],mx[N],las,tot;
void ins(int x,int op)
{
int p=las,np=nxt[las][x],t,nt;
if(np&&mx[p]+1==mx[np])
{
las=np,sz[np][op]=1;
// else
// {
// nt=++tot;
// mx[nt]=mx[p]+1;
// memcpy(nxt[nt],nxt[np],sizeof nxt[np]);
// fa[nt]=fa[np],fa[np]=nt;
// while(p&&nxt[p][x]==np)nxt[p][x]=nt,p=fa[p];
// sz[nt][op]=1,las=nt;
// }
return;
}
np=++tot,mx[np]=mx[p]+1,las=np,sz[np][op]++;
while(p&&!nxt[p][x])nxt[p][x]=np,p=fa[p];
if(!p){fa[np]=1;return;}
t=nxt[p][x];
if(mx[t]==mx[p]+1)fa[np]=t;
else
{
nt=++tot;mx[nt]=mx[p]+1;
memcpy(nxt[nt],nxt[t],sizeof nxt[t]);
fa[nt]=fa[t],fa[t]=fa[np]=nt;
while(p&&nxt[p][x]==t)nxt[p][x]=nt,p=fa[p];
}
}
int tax[M],rnk[N];
void sol()
{
ll ans=0;
for(int i=1;i<=tot;i++)tax[mx[i]]++;
for(int i=1;i<M;i++)tax[i]+=tax[i-1];
for(int i=1;i<=tot;i++)rnk[tax[mx[i]]--]=i;
for(int ii=tot,i;ii>=1;ii--)
{
i=rnk[ii];
sz[fa[i]][0]+=sz[i][0];
sz[fa[i]][1]+=sz[i][1];
ans+=1LL*(mx[i]-mx[fa[i]])*sz[i][0]*sz[i][1];
}
printf("%lld\n",ans);
}
}sam;
int n;
char S[N];
int main()
{
scanf("%s",S+1),n=strlen(S+1);
for(int i=1;i<=n;i++)sam.ins(S[i]-'a',0);
scanf("%s",S+1),n=strlen(S+1);
sam.las=1;
for(int i=1;i<=n;i++)sam.ins(S[i]-'a',1);//cerr<<sam.tot<<'\n';
sam.sol();
}
直接插分隔符暴力拼接写法:(注意下这种写法桶排的时候桶要开两个串总长)
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=8e5+100;
const int M=2e5+100;
struct Sam{
Sam(){las=tot=1;}
int nxt[N][27],fa[N],sz[N][2],mx[N],las,tot;
void ins(int x,int op)
{
int p=las,np=++tot,t,nt;
mx[np]=mx[p]+1,las=np,sz[np][op]++;
while(p&&!nxt[p][x])nxt[p][x]=np,p=fa[p];
if(!p){fa[np]=1;return;}
t=nxt[p][x];
if(mx[t]==mx[p]+1)fa[np]=t;
else
{
nt=++tot;mx[nt]=mx[p]+1;
memcpy(nxt[nt],nxt[t],sizeof nxt[t]);
fa[nt]=fa[t],fa[t]=fa[np]=nt;
while(p&&nxt[p][x]==t)nxt[p][x]=nt,p=fa[p];
}
}
int tax[N],rnk[N];
void sol()
{
ll ans=0;
for(int i=1;i<=tot;i++)tax[mx[i]]++;
for(int i=1;i<=tot;i++)tax[i]+=tax[i-1];
for(int i=1;i<=tot;i++)rnk[tax[mx[i]]--]=i;
for(int ii=tot,i;ii>=1;ii--)
{
i=rnk[ii];
sz[fa[i]][0]+=sz[i][0];
sz[fa[i]][1]+=sz[i][1];
ans+=1LL*(mx[i]-mx[fa[i]])*sz[i][0]*sz[i][1];
}
printf("%lld\n",ans);
}
}sam;
int n;
char S[N];
int main()
{
scanf("%s",S+1),n=strlen(S+1);
for(int i=1;i<=n;i++)sam.ins(S[i]-'a',0);
scanf("%s",S+1),n=strlen(S+1);
sam.ins(26,0);
for(int i=1;i<=n;i++)sam.ins(S[i]-'a',1);//cerr<<sam.tot<<'\n';
sam.sol();
}