luogu P4081
感觉不到黑题的难度,做完这道题目,对SAM添加字符操作有了更加深刻的理解。
题目描述:
给你一个整数n,然后给你n个串,要求求出n个串每个的只属于其的本质不同的非空字串的数量。
所有串的总长度不超过1e5,n不超过1e5
不知道每条串的最大长度,1e5*1e5的数组太大
可以用一个1e5的数组存 然后记录下每条串的长度
如果想用SAM来做这道题目我们要清楚SAM的原理
要了解SAM添加字符的时候到底在干嘛
(搞懂之后就知道广义SAM为啥成立了)
inline void addChar(int c){
int p=last;
int np=newNode(node[p].len+1, node[p].pos+1);
while(p && node[p].nxt[c]==0)
node[p].nxt[c]=np, p=node[p].fail;
if (p==0) node[np].fail=root;
else{
int q=node[p].nxt[c];
if (node[p].len+1 == node[q].len){
node[np].fail=q;
}
else
{
int nq=newNode(node[p].len+1, node[q].pos);
for (int i=0; i<kind; i++)
node[nq].nxt[i]=node[q].nxt[i];
node[nq].fail=node[q].fail;
node[q].fail=node[np].fail=nq;
while(p && node[p].nxt[c]==q)
node[p].nxt[c]=nq, p=node[p].fail;
}
}
last=np;
node[np].cnt=1;
}
首先找到尾结点,然后必定要新建一个节点(np),这个节点(np)也将成为新的尾结点(添加最后将last置为np),它保证了所有节点的最大长度等于加入字符的总数量(可以这样理解,每一个曾被置为last的节点,都是添加这个字符时,当时最长长度最大的节点),这个节点的父亲节点如果没有一条边(边为当前添加字符)那就为其父亲节点添加,并向上遍历父亲节点的父亲节点,直到有一个节点有一条边(边为当前添加字符)或者遍历到0节点。
例如下面:
SAM中的已经添加的串为ababa
(1)如果要向其中添加b ab已经出现过,所以向上遍历父亲节点,会出现一个节点有代表b的一条边连向其他点。
(2)如果要向其中添加c ac没有出现过,那么会一直向上遍历父亲节点,直到父亲节点=0(根节点为1,父亲节点等于0相当于结束)
如果向上沿着fail树遍历的节点p等于0
那np的fail就是root 且明显不需要新建节点(因为相当于重新加入了一个从未出现过的字符)
当p!=0时
就需要分情况讨论了
令q为p节点沿添加字符的边走向的节点
如果p.len+1等于q.len根据后缀自动机的定义 这两个节点其实也是一样的 不需要新建节点
但如果不等于
我们就需要新建节点了
其实出现这种情况就是因为p表示的长度要大于q这个节点的长度+1
例子:q节点表示的最长串为ababa
p节点表示的最长串为cccababac
我们需要新建一个节点nq使其表示的最长长度的串为ababac,且q以及其父亲节点经过新添加字符的边走向nq,并将p节点的父亲置为nq。(当我们操作多个字符串时,如果有这种情况新建的一个点nq,且p属于其他的串时,那这个节点nq所代表的所有子串就没有了贡献)。
理解了SAM添加字符时候的情况
我们将所有的串存入SAM(每次新串输入last置为1),这个是广义SAM和SAM的区别,因为我们再把last置为1后,之后申请的节点就如果和之前的串有公共子串,那么该子串对应的状态节点和我们之前前一条串申请的节点之间,会有fail边连接。
然后我们设置一个访问数组vis(大小为字符串长度的两倍(SAM性质))
这个数组记录了谁访问过该节点
然后把每个串再跑一遍SAM因为之前输入过,所以串每个字符跑的时候一定有匹配节点,我们把当前状态节点s,以及当先串是第几串a,然后沿着fail树遍历s的父亲和祖先,如果节点未被访问过,将其vis置为a,如果其被其他节点访问过,将其vis置为-1。如果vis已经为-1了,那就不需要向上遍历了,因为当前节点已经被至少两个串访问过,其祖先节点必定也被至少两个串访问过。
之后我们只要找到vis值不是-1的节点,然后对其代表的子串的数量进行统计然后输出即可。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
namespace SAM
{
const int maxn=2e5;
const int kind=26;
struct Node{
int nxt[kind], fail;
int len; // Max Length of State
int pos; // Appear Position of State, Indexed From 1
int cnt; // Appear Count of State
}node[maxn*2];
int numn, last, root;
inline int newNode(int l, int p){
int x=++numn;
for (int i=0; i<kind; i++) node[x].nxt[i]=0;
node[x].cnt=node[x].fail=0;
node[x].len=l;
node[x].pos=p;
return x;
}
inline void init(){
root=last=newNode(numn=0, 0);
}
inline void addChar(int c){
int p=last;
int np=newNode(node[p].len+1, node[p].pos+1);
while(p && node[p].nxt[c]==0)
node[p].nxt[c]=np, p=node[p].fail;
if (p==0) node[np].fail=root;
else{
int q=node[p].nxt[c];
if (node[p].len+1 == node[q].len){
node[np].fail=q;
}
else
{
int nq=newNode(node[p].len+1, node[q].pos);
for (int i=0; i<kind; i++)
node[nq].nxt[i]=node[q].nxt[i];
node[nq].fail=node[q].fail;
node[q].fail=node[np].fail=nq;
while(p && node[p].nxt[c]==q)
node[p].nxt[c]=nq, p=node[p].fail;
}
}
last=np;
node[np].cnt=1;
}
}
using namespace SAM;
long long ans[maxn];
char s[maxn],ss[maxn];
int len[maxn],vis[maxn];
int n;
inline void cla(int x,int y){
for(;x&&vis[x]!=y&&vis[x]!=-1;x=node[x].fail){
if(vis[x]!=0)vis[x]=-1;
else vis[x]=y;
}
}
void solve(){
int tot=0;
for(int i=1;i<=n;i++){
for(int j=1,x=1;j<=len[i];j++){
int v=s[tot]-'a';
tot++;
cla(x=node[x].nxt[v],i);
}
}
for(int i=1;i<=numn;i++)
{
if(vis[i]!=-1){
int x=vis[i];
ans[x]+=node[i].len-node[node[i].fail].len;
}
}
for(int i=1;i<=n;i++)printf("%lld\n",ans[i]);
}
int main(){
scanf("%d",&n);
int tot=0;
init();
memset(len,0,sizeof(len));
memset(vis,0,sizeof(vis));
memset(ans,0,sizeof(ans));
for(int i=1;i<=n;i++){
scanf("%s",ss);
last=1;//最核心的一句
int l=strlen(ss);
for(int j=0;j<l;j++){
len[i]++;
addChar(ss[j]-'a');
s[tot++]=ss[j];
}
}
s[tot]='\0';
solve();
return 0;
}
上一版代码很丑,而且其实因为是只是统计本质不同的子串,所以这个版本的写法是可以的,但是这种广义SAM写法会申请很多多余的节点,宏观感受下{ab,abc}
简易+不申请多余节点版
#include<bits/stdc++.h>
using namespace std;
const int maxn=5e5+150;
const int kind=26;
typedef long long ll;
int tot1=1,las=1;
int ch[maxn*2][kind];
int len[maxn*2],fa[maxn*2];
char s1[maxn],s2[maxn];
ll sum[maxn*2];
int d1[maxn*2],d2[maxn*2];
int n;
inline int newn(int x){len[++tot1]=x;return tot1;}
inline int newnq(int p,int w){
int nq=newn(len[p]+1);
int q=ch[p][w];
for(int i=0;i<kind;i++)ch[nq][i]=ch[q][i];
fa[nq]=fa[q];
fa[q]=nq;
while(p&&ch[p][w]==q)ch[p][w]=nq,p=fa[p];
return nq;
}
void sam_ins(int c){
int p=las;
if(ch[p][c]){
int q=ch[p][c];
if (len[q]==len[p]+1)las=q;
else las=newnq(p,c);
return ;
}
int np=newn(len[las]+1);las=tot1;
while(p&&!ch[p][c])ch[p][c]=np,p=fa[p];
if(!p)fa[np]=1;
else{
int q=ch[p][c];
if(len[q]==len[p]+1) fa[np]=q;
else{
fa[np]=newnq(p,c);
}
}
}
int vis[maxn*2];
inline void cla(int x,int y){
for(;x&&vis[x]!=y&&vis[x]!=-1;x=fa[x]){
if(vis[x]!=0)vis[x]=-1;
else vis[x]=y;
}
}
ll ans[maxn];
int le[maxn];
void solve(){
int dd=0;
for(int i=1;i<=n;i++){
for(int j=1,x=1;j<=le[i];j++){
int v=s2[dd]-'a';
dd++;
cla(x=ch[x][v],i);
}
}
for(int i=1;i<=tot1;i++)
{
if(vis[i]!=-1){
int x=vis[i];
ans[x]+=len[i]-len[fa[i]];
}
}
for(int i=1;i<=n;i++)printf("%lld\n",ans[i]);
}
int main(){
scanf("%d",&n);
int tot=0;
for(int i=1;i<=n;i++){
scanf("%s",s1);
las=1;
int l=strlen(s1);
for(int j=0;j<l;j++){
le[i]++;
sam_ins(s1[j]-'a');
s2[tot++]=s1[j];
}
}
s2[tot]='\0';
solve();
return 0;
}