Problem M. Mediocre String Problem
给定两个串 s , t s, t s,t,要求有多少不同的三元组 ( i , j , k ) (i, j, k) (i,j,k),满足:
- 1 ≤ i ≤ j ≤ ∣ s ∣ 1 \le i \le j \le \mid s \mid 1≤i≤j≤∣s∣。
- 1 ≤ k ≤ ∣ t ∣ 1 \le k \le \mid t \mid 1≤k≤∣t∣。
- j − i + 1 ≥ k j - i + 1 \ge k j−i+1≥k。
- 且 s [ i , j ] + t [ 1 , k ] s[i, j] + t[1, k] s[i,j]+t[1,k]是一个回文串。
设 s [ i , j ] + t [ 1 , k ] s[i, j] + t[1, k] s[i,j]+t[1,k]连接而成的串为 A + T + B A + T + B A+T+B, s [ i , j ] = A + T , t [ 1 , k ] = B s[i, j] = A + T, t[1, k] = B s[i,j]=A+T,t[1,k]=B,则一定满足 r e v ( A ) = B rev(A) = B rev(A)=B, T T T其本身是一个非空回文串。
我们考虑把 s s s串,翻转,那么问题就变成了:
- 1 ≤ i ≤ j ≤ ∣ s ∣ 1 \le i \le j \le \mid s \mid 1≤i≤j≤∣s∣。
- 1 ≤ k ≤ ∣ t ∣ 1 \le k \le \mid t \mid 1≤k≤∣t∣。
- j − i + 1 ≥ k j - i + 1 \ge k j−i+1≥k。
- s [ i , j ] + t [ 1 , k ] = T + A + B s[i, j] + t[1, k] = T + A + B s[i,j]+t[1,k]=T+A+B,其中 T + A = s [ i , j ] , B = t [ 1 , k ] T + A = s[i, j], B = t[1, k] T+A=s[i,j],B=t[1,k],满足 A = B A = B A=B,且 T T T是一个非空回文串。
可以考虑用 e x k m p exkmp exkmp,得到 s s s的所有后缀与 t t t中前缀的最长匹配长度,然后枚举点 i i i,则答案为 ∑ i = 1 m p a l i n d r o m e _ s u m [ i − 1 ] × l e n [ i ] \sum\limits_{i = 1} ^{m} palindrome\_sum[i - 1] \times len[i] i=1∑mpalindrome_sum[i−1]×len[i]。
palindrome_sum[i] 表示以 i 结尾的回文串有多少个,可以用回文树,简单求得。
所以只要套上 E X K M P , P A M EXKMP, PAM EXKMP,PAM即可。
#include <bits/stdc++.h>
using namespace std;
namespace PAM {
const int N = 1e6 + 10;
int sz, tot, last, cnt[N], nex[N][26], len[N], fail[N], dep[N], palindrome_sum[N];
char s[N];
int node(int l) {
++sz, len[sz] = l, fail[sz] = cnt[sz] = 0;
return sz;
}
void init() {
sz = -1, last = 0, s[tot = 0] = '$';
node(0), node(-1), fail[0] = 1;
}
int getFail(int rt) {
while (s[tot - len[rt] - 1] != s[tot]) {
rt = fail[rt];
}
return rt;
}
void insert(char c, int id) {
s[++tot] = c;
int cur = getFail(last);
if (!nex[cur][c - 'a']) {
int x = node(len[cur] + 2);
fail[x] = nex[getFail(fail[cur])][c - 'a'];
nex[cur][c - 'a'] = x;
}
last = nex[cur][c - 'a'];
cnt[last]++, dep[last] = dep[fail[last]] + 1;
palindrome_sum[id] = dep[last];
}
}
const int N = 2e6 + 10;
char str1[N], str2[N];
int a[N], n, m;
int main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
scanf("%s %s", str2 + 1, str1 + 1);
n = strlen(str1 + 1), m = strlen(str2 + 1);
reverse(str2 + 1, str2 + m + 1);
for (int i = 1; i <= m; i++) {
str1[i + n] = str2[i];
}
for (int i = 2, l = 1, r = 1; i <= n + m; i++) {
if (i <= r && a[i - l + 1] < r - i + 1) {
a[i] = a[i - l + 1];
}
else {
a[i] = max(0, r - i + 1);
while (i + a[i] <= n + m && str1[a[i] + 1] == str1[i + a[i]]) {
a[i]++;
}
}
if (i + a[i] - 1 > r) {
l = i, r = i + a[i] - 1;
}
}
PAM::init();
for (int i = 1; i <= m; i++) {
PAM::insert(str2[i], i);
}
long long ans = 0;
for (int i = n + 1; i <= n + m; i++) {
ans += 1ll * PAM::palindrome_sum[i - n - 1] * min({n, n + m - i + 1, a[i]});
}
printf("%lld\n", ans);
return 0;
}