题目大意
给定两个字符串,s,t然后询问有多少个三元组满足 s[i~j] + t[1, k]并且i到j的长度大于k使得拼接的字符串是个回文字符串。
思路 & 代码
将s逆序,得到ss,对其和t求扩展KMP
得到ss[i…n-1] 和 t[0…m-1]的最长公共前缀。
然后其每个前缀的长度 * 以i结尾后缀回文的数量再求个和就ok
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 1e6 + 100;
const int N = 26;
char s[MAXN], t[MAXN];
int extend[MAXN], nxt[MAXN];
int n,m;
void pre_EKMP(char x[], int m, int nxt[]) {
nxt[0] = m;
int j = 0;
while(j + 1 < m && x[j] == x[j+1]) j++;
nxt[1] = j;
int k = 1;
for(int i = 2; i < m; i++) {
int p = nxt[k] + k - 1;
int L = nxt[i - k];
if(i + L < p + 1) nxt[i] = L;
else {
j = max(0, p - i + 1);
while(i + j < m && x[i + j] == x[j]) j++;
nxt[i] = j;
k = i;
}
}
}
void EKMP(char x[], int m, char y[], int n, int nxt[], int extend[]) {
pre_EKMP(x, m, nxt);
int j = 0;
while(j < n && j < m && x[j] == y[j]) j++;
extend[0] = j;
int k = 0;
for(int i = 1; i < n; i++) {
int p = extend[k] + k - 1;
int L = nxt[i - k];
if(i + L < p + 1) extend[i] = L;
else {
j = max(0, p - i + 1);
while(i + j < n && j < m && y[i + j] == x[j]) j++;
extend[i] = j;
k = i;
}
}
}
struct PAM {
int ch[MAXN][N];// 往一个字符串左右添加一个字符对应的结点
int fail[MAXN]; // 对于后缀回文串来说,失配后能到达的下一个后缀回文
int cnt[MAXN]; // 本质不同的结点数量,最后需要count函数得到正确结果
int num[MAXN]; // num[i]表示以i结尾的后缀回文串的数量
int len[MAXN]; // 每个节点代表的回文串长度
int S[MAXN]; // 存字符串
int last; // 对应最长后缀回文节点
int n; // 字符集数量
int p; // 回文树结点数量
int newnode(int l) {
for(int i = 0; i < N; i++) ch[p][i] = 0;
cnt[p] = num[p] = 0;
len[p] = l;
return p++;
}
void init() {
p = 0;
newnode(0); // 偶结点
newnode(-1); // 奇结点
last = 0;
n = 0;
S[n] = -1;
fail[0] = 1;
}
int get_fail(int x) {
while(S[n - len[x] - 1] != S[n]) x = fail[x];
return x;
}
int add(int c) {
c-='a';
S[++n] = c;
int cur = get_fail(last);
if(!ch[cur][c]) {
int now = newnode(len[cur] + 2);// 新添加一个结点
fail[now] = ch[get_fail(fail[cur])][c];
ch[cur][c] = now;
num[now] = num[fail[now]] + 1;
}
last = ch[cur][c];
cnt[last]++;
return num[last];
}
void count() {
for(int i = p - 1; i >= 0; --i) {
cnt[fail[i]] += cnt[i];
}
}
} pam;
int num[MAXN];
int main() {
//freopen("/Users/maoxiangsun/MyRepertory/input.txt", "r", stdin);
scanf("%s%s",s,t);
n = (int)strlen(s);
m = (int)strlen(t);
reverse(s, s + n);
EKMP(t, m, s, n, nxt, extend);
// OK
pam.init();
for(int i = 0; i < n; i++) {
num[i] = pam.add(s[i]);
}
ll res = 0;
for(int i = 1; i < n ; i++) {
res += 1LL * num[i-1] * extend[i];
}
cout << res << endl;
return 0;
}