求两个字符串中各取出一个串使他们相同的方案数
第一个串建SAM,第二个串在上面跑匹配。
这里需要从前面遍历到后面
用size表示它的个数,求每个点的贡献,这个点已经匹配上了,那么它的父节点也是可以的,所有结果是f[fa[p]]+size[p]∗(mx[p]−mx[fa[p]])
#include <cstdio>
#include <cstring>
#define N 200010
#define ll long long
int root,last,cnt=0,size[N<<1],mx[N<<1],son[N<<1][26],c[N<<1],a[N<<1],fa[N<<1];
ll ans=0,f[N<<1];
char str1[N],str2[N];
inline void ins(int ch){
int p=last,np=++cnt;mx[np]=mx[p]+1;last=np;size[np]=1;
while(p && !son[p][ch]) son[p][ch]=np,p=fa[p];
if(!p) fa[np]=root;
else{
int q=son[p][ch];
if(mx[q]==mx[p]+1) fa[np]=q;
else{
int nq=++cnt;mx[nq]=mx[p]+1;
memcpy(son[nq],son[q],sizeof(son[nq]));
fa[nq]=fa[q];fa[q]=fa[np]=nq;
while(son[p][ch]==q) son[p][ch]=nq,p=fa[p];
}
}
}
int main(){
scanf("%s%s",str1+1,str2+1);
int n=strlen(str1+1);root=last=++cnt;
for(int i=1;i<=n;i++) ins(str1[i]-'a');
for(int i=1;i<=cnt;i++) c[mx[i]]++;
for(int i=1;i<=cnt;i++) c[i]+=c[i-1];
for(int i=cnt;i>=1;i--) a[c[mx[i]]--]=i;
for(int i=cnt;i>=1;i--)
size[fa[a[i]]]+=size[a[i]];
for(int i=1;i<=cnt;i++) f[a[i]]=f[fa[a[i]]]+(ll)size[a[i]]*(mx[a[i]]-mx[fa[a[i]]]);
int p=root,len=0;n=strlen(str2+1);
for(int i=1;i<=n;i++){
int ch=str2[i]-'a';
if(son[p][ch]) len++,p=son[p][ch];
else{
while(p&&!son[p][ch]) p=fa[p];
if(!p) len=0,p=root;
else len=mx[p]+1,p=son[p][ch];
}ans+=f[fa[p]]+(ll)size[p]*(len-mx[fa[p]]);
}printf("%lld",ans);
return 0;
}