做法
exkmp + manacher
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e6 + 10;
char s[maxn], ma[maxn * 2], t[maxn];
int mp[maxn * 2];
int q, nxt[maxn], extend[maxn], slen, tlen; // extend[i]:s[i...n] 与 t 的最长公共前缀长度
void getnxt(){
nxt[0] = tlen;//nxt[0]一定是T的长度
int now = 0;
while(t[now] == t[1 + now] && now + 1 < tlen) now++; //这就是从1开始暴力
nxt[1] = now;
int p0 = 1;
for(int i = 2; i < tlen; i++){
if(i + nxt[i - p0] < nxt[p0] + p0) nxt[i] = nxt[i - p0]; //第一种情况
else{ //第二种情况
int now = nxt[p0] + p0 - i;
now = max(now, 0);//这里是为了防止i>p的情况
while(t[now] == t[i + now] && i + now < tlen) now++; //暴力
nxt[i] = now;
p0 = i; //更新p0
}
}
}
void exkmp(){
getnxt();
int now = 0;
while(s[now] == t[now] && now < min(slen, tlen)) now++; //暴力
extend[0] = now;
int p0 = 0;
for(int i = 1; i < slen; i++){
if(i + nxt[i - p0] < extend[p0] + p0) extend[i] = nxt[i - p0]; //第一种情况
else{//第二种情况
int now = extend[p0] + p0 - i;
now = max(now, 0); //这里是为了防止i>p的情况
while(t[now] == s[i + now] && now < tlen && now + i < slen) now++; //暴力
extend[i] = now;
p0 = i; //更新p0
}
}
}
void Manacher(int len){
int l = 0;
ma[l++] = '$', ma[l++] = '#';
for(int i = 0; i < len; i++){
ma[l++] = s[i];
ma[l++] = '#';
}
ma[l] = 0;
int mx = 0, id = 0;
for(int i = 0; i < l; i++){
mp[i] = mx > i ? min(mp[2 * id - i], mx - i) : 1;
while(ma[i + mp[i]] == ma[i - mp[i]]) mp[i]++;
if(i + mp[i] > mx){
mx = i + mp[i];
id = i;
}
}
}
ll sum[maxn];
ll get_sum(int l, int r){
if(l > r) return 0;
if(l <= 0) return sum[r];
return sum[r] - sum[l - 1];
}
int main(){
scanf("%s %s", s, t);
slen = strlen(s), tlen = strlen(t);
Manacher(slen);///mp[i] - 1才是串的长度
reverse(s, s + slen);
exkmp();///获得了extend[]和nxt[]
reverse(extend, extend + slen);
sum[0] = extend[0];
for(int i = 1; i < slen; i++){
sum[i] = sum[i - 1] + extend[i];
}
ll ans = 0;
///下标:0 1 2 3 4
///原串:a b a b a
///下标:0 1 2 3 4 5 6 7 8 9 10 11
///新串:$ # a # b # a # b # a #
///半径:1 1 2 1 4 1 6 1 4 1 2 1 0
for(int i = 2; i < slen * 2 + 3; i++){///对于s的每个位置,看回文串
int cnt = mp[i] - 1;///回文串的长度
if(cnt == 0 || mp[i] == 0) continue;
int l, r, pos = (i - 2) / 2;
if(cnt & 1){///回文串的长度为奇数,如ababa
l = pos - (cnt - 1) / 2 - 1;
r = pos - 1;
}else{
l = pos - cnt / 2;
r = pos - 1;
}
ans += get_sum(l, r);
}
cout << ans << endl;
return 0;
}