对两个字符串都构建后缀自动机
对第一个字符串,求F[i]表示从rt到i节点有多少不同的路径(即求不同的子串的个数);
对第二个字符串,求G[i]表示从节点i开始往下走能走出多少不同的路径(即求不同的子串的个数);
这两个动态规划都可以用拓扑排序和ch[][]数组转移完成,第二个dp可能需要将ch[][]里面的每一条边反向。
枚举第一个后缀自动机里的节点u,若该节点的出边j不存在,则ans += F[u] + G[ ch[rt][j] ]
思路是对于相等的a+b和a'+b',取在第一个字符串里字串较长的那个。
是在学校OJ上面提交的,一定要用unsigned long long才放过去
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>
#define N 400050
using namespace std;
typedef long long LL;
vector<int> e[N];
char s[N];
int n;
struct SAM{
int ch[N][27] , len[N] , link[N] , rd[N];
LL F[N];
int cnt=1 , las=1 , rt = 1;
void add(int pos) {
int x = s[pos] - 'a' + 1 , np = ++cnt , p = las;
las = np; len[np] = len[p] + 1;
while (p && !ch[p][x]) ch[p][x] = np , p = link[p];
if (!p)
link[np] = rt;
else {
int q = ch[p][x];
if (len[q] == len[p] + 1)
link[np] = q;
else {
int nq = ++cnt;
len[nq] = len[p] + 1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
link[nq] = link[q];
link[q] = link[np] = nq;
while (p && ch[p][x] == q) ch[p][x] = nq , p = link[p];
}
}
return ;
}
}A,B;
int main() {
scanf("%s",s+1); n = strlen(s+1);
for (int i=1;i<=n;i++) A.add(i);
scanf("%s",s+1); n = strlen(s+1);
for (int i=1;i<=n;i++) B.add(i);
A.F[1] = B.F[1] = 1LL;
for (int i=1;i<=A.cnt;i++) A.F[i] = 1;
for (int i=1;i<=B.cnt;i++) B.F[i] = 1;
queue<int> q;
for (int i=1;i<=B.cnt;i++)
for (int j=1;j<=26;j++)
if (B.ch[i][j]) e[ B.ch[i][j] ].push_back(i);
for (int i=1;i<=B.cnt;i++)
for (int j=0;j<(int)e[i].size();j++) B.rd[ e[i][j] ]++;
for (int i=1;i<=B.cnt;i++) if (!B.rd[i]) q.push(i);
while (!q.empty()) {
int u = q.front(); q.pop();
for (int i=0;i<(int)e[u].size();i++)
B.F[ e[u][i] ] += B.F[u];
for (int i=0;i<(int)e[u].size();i++)
if ( --B.rd[ e[u][i] ] == 0) q.push(e[u][i]);
}
memset(A.F,0,sizeof(A.F)); A.F[1] = 1LL;
for (int i=1;i<=A.cnt;i++)
for (int j=1;j<=26;j++)
if (A.ch[i][j]) A.rd[ A.ch[i][j] ]++;
for (int i=1;i<=A.cnt;i++) if (!A.rd[i]) q.push(i);
while (!q.empty()) {
int u = q.front(); q.pop();
for (int i=1;i<=26;i++) if (A.ch[u][i])
A.F[ A.ch[u][i] ] += A.F[u];
for (int i=1;i<=26;i++)
if ( --A.rd[ A.ch[u][i] ] == 0) q.push(A.ch[u][i]);
}
unsigned long long ans = 0LL;
for (int i=1;i<=A.cnt;i++) {
for (int j=1;j<=26;j++) if (!A.ch[i][j])
ans += 1LL * A.F[i] * (B.F[ B.ch[1][j] ]);
ans += 1LL * A.F[i];
}
cout << ans << endl;
return 0;
}