建立一个广义后缀自动机，然后对每个节点统计：子串个数 × a 串在该节点的出现次数 × b 串在该节点的出现次数。利用 DFS 序遍历所有节点，在返回时将子节点的大小累计至父节点即可。
#include<bits/stdc++.h>
#define endl '\n'
using namespace std;
// 64-bit integer type used for the accumulated answer.
typedef long long LL;
// Capacity shared by the input buffer, the automaton state pool and the edge pool.
const int maxn = 800000 + 5;
// Buffer each input string is read into.
char s[maxn];
// sz: index of the most recently allocated automaton state (state 0 is the root).
int sz;
// last: state that the next inserted character is appended after.
int last;
// cnt: number of entries of edge[] already in use.
int cnt;
// head[u]: index in edge[] of the first edge leaving u, or -1 when u has none.
int head[maxn];
// size[v][k]: counter of state v for the string tagged k (0 or 1); extend() sets it
// to 1 on the state reached after each character and dfs() sums it over the subtree.
int size[maxn][2];
// Result accumulated by dfs() and printed by main().
LL ans;
// One state of the suffix automaton; all states live in the global pool st[].
struct state{
    // Length value of the state; extend() assigns "previous len + 1" to new states and clones.
    int len;
    // Suffix link: index of another state in st[], or -1 for the root (state 0).
    int link;
    // Transition per letter (index = letter - 'a'); 0 is used as "no transition".
    int next[26];
}st[maxn];
// One entry of the adjacency lists built by add(); entries are chained through `next`.
struct node{
    // Destination vertex of this edge.
    int v;
    // Index in edge[] of the next edge with the same source, or -1 at the end of the chain.
    int next;
}edge[maxn];
// Inserts the directed edge u -> v at the front of u's adjacency list.
//   u: source vertex (index into head[])
//   v: destination vertex
inline void add(int u, int v){
    const int id = cnt++;      // claim the next free slot of edge[]
    edge[id].v = v;
    edge[id].next = head[u];   // the previous first edge of u becomes the second one
    head[u] = id;
}
// Resets the counters, the root state and the adjacency heads before building.
// Only len/link of state 0 are written here; its transitions are left untouched.
inline void init(){
    cnt = 0;
    last = 0;
    sz = 0;
    // The root has length 0 and no suffix link.
    st[0].link = -1;
    st[0].len = 0;
    // All-ones bytes make every head[] entry -1, i.e. "no edges".
    memset(head, -1, sizeof(head));
}
inline void extend(char ch, int no){ // 广义 SAM
int c = ch - 'a';
if(st[last].next[c]){
int p = last, x = st[last].next[c];
if(st[x].len == st[p].len + 1)
last = x;
else{
int y = ++sz;
st[y] = st[x];
st[y].len = st[p].len + 1;
st[x].link = y;
while(~p && st[p].next[c] == x){
st[p].next[c] = y;
p = st[p].link;
}
last = y;
}
size[last][no] = 1;
return;
}
int now = ++sz;
st[now].len = st[last].len + 1;
int p = last;
while(~p && !st[p].next[c]){
st[p].next[c] = now;
p = st[p].link;
}
if(p == -1)
st[now].link = 0;
else{
int q = st[p].next[c];
if(st[p].len + 1 == st[q].len)
st[now].link = q;
else{
int clone = ++sz;
st[clone] = st[q];
st[clone].len = st[p].len + 1;
st[q].link = st[now].link = clone;
while(~p && st[p].next[c] == q){
st[p].next[c] = clone;
p = st[p].link;
}
}
}
last = now;
size[last][no] = 1;
}
inline void dfs(int u){
for(int k = head[u]; ~k; k = edge[k].next){
dfs(edge[k].v);
size[u][0] += size[edge[k].v][0];
size[u][1] += size[edge[k].v][1];
}
ans += (LL)(st[u].len - st[st[u].link].len) * size[u][0] * size[u][1];
}
// Reads two strings, inserts both into one automaton, builds the suffix-link tree
// and prints the value accumulated by dfs().
int main(){
    cin.tie(0);
    cout.tie(0);
    ios::sync_with_stdio(false);

    init();
    // The first string is tagged 1 and the second 0.
    for(int tag = 1; tag >= 0; --tag){
        cin >> s;
        last = 0; // every string is inserted starting from the root
        for(int i = 0; s[i] != '\0'; ++i){
            extend(s[i], tag);
        }
    }

    // One tree edge from each state's suffix link to the state itself.
    for(int v = 1; v <= sz; ++v){
        add(st[v].link, v);
    }
    dfs(0);
    cout << ans;
}