同步于:https://www.luogu.com.cn/blog/OUYE2020/solution-p7469
题解
考虑字符串
T
T
T 中的本质不同的子串最多
n
2
n^2
n2 个,暴力枚举判断至少是
O
(
n
3
)
O(n^3)
O(n3) 的,用
hash
\text{hash}
hash 可以优化到
O
(
n
2
)
O(n^2)
O(n2) 。然而为了保证正确性,我们不用
hash
\text{hash}
hash (其实我做字符串题就只用过一次这玩意 。
那么为了去重,最好的方法是对 T T T 串建立一个后缀自动机,按边遍历整个后缀自动机就可以像走 trie \text{trie} trie 树一样得到每个子串的信息。由于后缀自动机上最多 2 n 2n 2n 个点,每个点会遍历不超过 n n n 次(其实比 n n n 小得多),所以遍历一遍均摊复杂度 O ( n 2 ) O(n^2) O(n2) ,在 n > 26 n>26 n>26 的情况下严格小于 O ( n 2 ) O(n^2) O(n2) 。
怎么拿 T T T 中的子串和 S S S 中的子序列匹配呢?我们贪心地记录 f i , c f_{i,c} fi,c 为第 i i i 个字符往后第一次出现字符 c c c 的位置,把 i i i 和 f i , c f_{i,c} fi,c 之间连边,最终可得到一颗最多 n ∗ 26 n*26 n∗26 条边的 trie \text{trie} trie 树。根据贪心思路,这棵树必定记录了 S S S 的所有子序列的信息,这颗 trie \text{trie} trie 树就叫 S S S 的子序列自动机。
然后我们只需要同步遍历两个自动机,求出所有重合的节点数就是答案。这个复杂度为两个自动机遍历次数取较小值,所以总复杂度小于 O ( n 2 ) O(n^2) O(n2) 。
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<stack>
#define ll long long
#define MAXN 3005
#define uns unsigned
#define INF 0x3f3f3f3f
using namespace std;
inline ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+s-'0',s=getchar();
return f?x:-x;
}
struct SAM{
int ch[26],len,fa;
SAM(){memset(ch,0,sizeof(ch)),len=fa=0;}
}sam[MAXN<<1];
int las=1,tot=1;
inline void samadd(int c){
int p=las,np=las=++tot;sam[np].len=sam[p].len+1;
for(;p&&sam[p].ch[c]==0;p=sam[p].fa)sam[p].ch[c]=np;
if(!p)sam[np].fa=1;
else{int q=sam[p].ch[c];
if(sam[q].len==sam[p].len+1)sam[np].fa=q;
else{
int nq=++tot;sam[nq]=sam[q],sam[nq].len=sam[p].len+1,sam[q].fa=sam[np].fa=nq;
for(;p&&sam[p].ch[c]==q;p=sam[p].fa)sam[p].ch[c]=nq;
}
}
}
int n,ans;
char a[MAXN],b[MAXN];
int tr[MAXN][26];
inline void dfs(int x,int y){
if(!x||!y)return;
if(x>1)ans++;
for(int i=0;i<26;i++)
dfs(sam[x].ch[i],tr[y][i]);
}
signed main()
{
// freopen("block.in","r",stdin);
// freopen("block.out","w",stdout);
n=read();
scanf("%s\n%s",a+2,b+1);
for(int i=n;i>0;i--){
for(int j=0;j<26;j++)tr[i][j]=tr[i+1][j];
tr[i][a[i+1]-'a']=i+1;
}
for(int i=1;i<=n;i++)samadd(b[i]-'a');
dfs(1,1);
printf("%d\n",ans);
return 0;
}