这种需要匹配的问题都是将一个串建成自动机,别的串上去跑。
在跑的时候叠加答案,我们发现如果找上了一个结点,那这个结点的所有fail,作为这个子串的后缀,都应该被匹配。所以答案应该从fail链上累加下来。而每个点代表一个串集合。串的个数是maxlen-minlen+1。而minlen=fail的maxlen+1。所以直接减去fail的maxlen就能得到本点代表的字符串数。
先fail树,然后dfs得size,然后再从上到下传递信息得到key。最后run。
注意在run的时候不能直接使用now的len去减fail的len。因为字符往下走一位并不一定能涵盖下一点所代表的所有串集。
#include<bits/stdc++.h>
using namespace std;
#define int long long
struct node{
int ch[26],fail,len;
}t[500003];int last=1,cnt=1;
int size[500003],key[500003];
void insert(int x){
int p=last;int now=++cnt;last=cnt;t[now].len=t[p].len+1;size[now]=1;
for(;p&&!t[p].ch[x];p=t[p].fail)t[p].ch[x]=now;
if(!p)t[now].fail=1;
else{
int q=t[p].ch[x];if(t[q].len==t[p].len+1)t[now].fail=q;
else{
int tem=++cnt;t[tem]=t[q];t[tem].len=t[p].len+1;t[q].fail=t[now].fail=tem;
for(;p&&t[p].ch[x]==q;p=t[p].fail)t[p].ch[x]=tem;
}
}
}
int first[400003],nxt[400003],to[400003],tot;
void add(int a,int b){
nxt[++tot]=first[a];first[a]=tot;to[tot]=b;
}
void dfs(int u){
for(int i=first[u];i;i=nxt[i]){
int v=to[i];dfs(v);size[u]+=size[v];
}
}
void dfs2(int u){
key[u]=key[t[u].fail]+size[u]*(t[u].len-t[t[u].fail].len);
for(int i=first[u];i;i=nxt[i])dfs2(to[i]);
}
char s[200003],tt[200003];int len;
signed main(){
int ans=0;
scanf("%s",s+1);len=strlen(s+1);for(int i=1;i<=len;i++)insert(s[i]-'a');
scanf("%s",tt+1);len=strlen(tt+1);
for(int i=2;i<=cnt;i++)add(t[i].fail,i);dfs(1);dfs2(1);
int now=1,lenn=0;
// for(int i=1;i<=cnt;i++)cout<<t[i].fail<<" ";cout<<endl;
// for(int i=1;i<=cnt;i++)cout<<t[i].len<<" ";cout<<endl;
// for(int i=1;i<=cnt;i++)cout<<size[i]<<" ";cout<<endl;
// for(int i=1;i<=cnt;i++)cout<<key[i]<<" ";cout<<endl;
for(int i=1;i<=len;i++){
int x=tt[i]-'a';
if(t[now].ch[x]){
lenn++,now=t[now].ch[x];
}else{
while(now&&!t[now].ch[x])now=t[now].fail;
if(!now)lenn=0,now=1;
else lenn=t[now].len+1,now=t[now].ch[x];
}
ans+=key[t[now].fail]+size[now]*(lenn-t[t[now].fail].len);
}cout<<ans;
return 0;
}