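// Approach:
// 1. Reverse the string s.
// 2. Run extended KMP (Z-algorithm) to match every suffix of the reversed s
//    against the prefixes of t.
// 3. Use a palindromic automaton to compute, for each position i, how many
//    palindromic substrings end at character i.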
#include<bits/stdc++.h>
using namespace std;
// Size of every global buffer/array below.
const int N = 2e6+100;
// s1: first input string as read by main().
// s : reversed copy of s1; 0-indexed while exkmp() runs, then shifted to
//     1-indexed with the sentinel s[0] = '&' before pam() runs.
// s2: not referenced anywhere in the visible code.
// t : second input string (0-indexed).
char s1[N],s[N],s2[N],t[N];
// lens / lent: lengths of s1 (and s) and of t.
// Next[i]: length of the longest common prefix of t and t[i..] (filled by getnext()).
// ext[i] : length of the longest common prefix of t and s[i..] (filled by exkmp()).
int lens,lent,Next[N],ext[N];
// Palindromic automaton state (filled by pam()):
// sz      : index of the most recently created node (nodes 0 and 1 are the two roots).
// f1[x]   : suffix link of node x.
// ch[x][c]: child of node x reached by wrapping it in letter 'a'+c (0 = absent).
// l[i]    : num[] of the node reached after processing s[i] (1-indexed position).
// num[x]  : num of x's suffix link plus one.
// len[x]  : palindrome length stored at node x (len[0] = 0, len[1] = -1).
int sz,f1[N],ch[N][26],l[N],num[N],len[N];
/**
 * Follows suffix links starting at node x until reaching a node whose
 * palindrome, when it ends just before position y, is preceded by the same
 * character as s[y] (i.e. it can be extended on both sides by s[y]).
 *
 * @param x  node of the palindromic automaton to start from
 * @param y  1-indexed position in s of the character being appended
 * @return   the first node on the suffix-link chain of x that satisfies
 *           s[y] == s[y - len[node] - 1]
 *
 * Implemented as a loop rather than recursion: a single call may walk a
 * suffix-link chain whose length is proportional to the string length, and
 * one stack frame per hop can overflow the stack when the compiler does not
 * eliminate the tail call.
 */
int find(int x, int y){
    // The walk always stops at the latest at node 1, whose len is -1, because
    // the test then compares s[y] with itself.
    while (s[y] != s[y - len[x] - 1]) {
        x = f1[x];
    }
    return x;
}
void pam(){
int now = 0;
sz = 1;
f1[0] = f1[1] = 1;
len[0] = 0; len[1] = -1;
for (int i = 1; i <= lens; ++i){
now = find(now,i);
if (!ch[now][s[i]-'a']){
len[++sz] = len[now]+2;
f1[sz] = ch[find(f1[now],i)][s[i]-'a'];
ch[now][s[i]-'a'] = sz;
num[sz] = num[f1[sz]] + 1;
}
now = ch[now][s[i]-'a'];
l[i] = num[now];
}
}
/**
 * Fills Next[] for the pattern t (0-indexed, length lent):
 * Next[i] is the length of the longest common prefix of t and t[i..].
 *
 * `anchor` is the start of the most recently extended match; the window
 * [anchor, anchor + Next[anchor]) is reused to avoid re-comparing characters.
 */
void getnext(){
    Next[0] = lent;

    // Next[1] is computed by direct comparison of adjacent characters.
    int run = 0;
    while (run + 1 < lent && t[run] == t[run + 1]) {
        ++run;
    }
    Next[1] = run;

    int anchor = 1;
    for (int pos = 2; pos < lent; ++pos) {
        const int reach = Next[anchor] + anchor;
        const int mirrored = Next[pos - anchor];

        // The mirrored value ends strictly inside the window: copy it.
        if (mirrored + pos < reach) {
            Next[pos] = mirrored;
            continue;
        }

        // Otherwise start from what the window guarantees and extend by
        // explicit comparison.
        int matched = reach - pos;
        if (matched < 0) {
            matched = 0;
        }
        while (pos + matched < lent && t[pos + matched] == t[matched]) {
            ++matched;
        }
        Next[pos] = matched;
        anchor = pos;
    }
}
/**
 * Extended KMP: fills ext[] so that ext[i] is the length of the longest
 * common prefix of t and s[i..], with s read 0-indexed over [0, lens).
 *
 * Calls getnext() first to build Next[] for t, then reuses the window of the
 * most recently extended match (`anchor`) in the same way getnext() does.
 */
void exkmp(){
    getnext();

    // ext[0] is computed by direct comparison.
    int head = 0;
    while (head < lens && head < lent && s[head] == t[head]) {
        ++head;
    }
    ext[0] = head;

    int anchor = 0;
    for (int pos = 1; pos < lens; ++pos) {
        const int reach = ext[anchor] + anchor;
        const int mirrored = Next[pos - anchor];

        // The mirrored value ends strictly inside the window: copy it.
        if (mirrored + pos < reach) {
            ext[pos] = mirrored;
            continue;
        }

        // Otherwise start from what the window guarantees and extend by
        // explicit comparison.
        int matched = reach - pos;
        if (matched < 0) {
            matched = 0;
        }
        while (pos + matched < lens && matched < lent && s[pos + matched] == t[matched]) {
            ++matched;
        }
        ext[pos] = matched;
        anchor = pos;
    }
}
/**
 * Reads two strings, reverses the first into s, and prints
 *   sum over i in [1, lens) of ext[i] * l[i]
 * where ext[] comes from exkmp() on the 0-indexed reversed string and l[]
 * comes from pam() on the same string shifted to 1-indexed.
 */
int main(){
    scanf("%s",s1);
    scanf("%s",t);
    lens = strlen(s1);
    lent = strlen(t);

    // s = reverse(s1), 0-indexed, as exkmp() expects.
    std::reverse_copy(s1, s1 + lens, s);
    exkmp();

    // Shift s one slot to the right so it becomes 1-indexed for pam(), and
    // put a non-letter sentinel in front of it.
    memmove(s + 1, s, static_cast<size_t>(lens));
    s[0] = '&';
    pam();

    // ext[i] is indexed on the 0-based string and l[i] on the 1-based one, so
    // the product pairs l at the character just before the suffix ext[i]
    // describes.
    long long total = 0;
    for (int i = 1; i < lens; ++i) {
        total += 1LL * ext[i] * l[i];
    }
    printf("%lld\n",total);
    return 0;
}
/*
ababa
aba
aabbaa
aabb
*/