Problem M. Mediocre String Problem(马拉车+拓展KMP)
题意:给一个S串一个T串, 问有多少个F(i, j, k),F(i, j, k) 的定义是S串选个下标i~j的子串,与T串中1~k的子串合并, 合并后结果为回文串且 k < (j-i+1), 1 <= i <= j <= lens, 1 <= k <= lent。
思路:
将S串翻转后进行操作,先用马拉车将S串每个字符为结尾时的回文个数计算出来也就是cnt[i],计算过程是利用差分。
再将翻转后的S与T进行拓展KMP求出LCP,最后直接累加乘积就行。
给个例子:
S: bbabbaa
T: bbaab
S翻转后: aabbabb
T不变:bbaab
拓kmp求完后,S串对应的extend[i] = {0, 0, 3, 1, 0, 2, 1}(S[i...lens-1]与T[0...lent-1]的最长公共前缀)
你会发现
0031021
aabbabb
bbaab
S的bba与T的bba是构成回文串的(S翻转了,按原来的看, abb+bba就是回文,然后要看这俩串之间有多少个回文 也就是对应的cnt[i-1],i为2第一个b的为止,cnt[i-1]就是第二个a为结尾的回文个数。
Tip: 我的cnt[]是1~lens, extend[]是0~lens-1
为什么要相乘,展开一下就知道了
{
abb (a) bba, abb(aa)bba, bb(a)bb, bb(aa)bb, b(a)b, b(aa)b
这不就是(lcp*S串公共前缀前面第一个字符为结尾时的回文个数)嘛
所以当遍历到i = 2时已经有6个了,继续按这个思路走,遍历完S串后,最后答案为13。
}
#include <bits/stdc++.h>
#define mem(a, b) memset(a, b, sizeof(a))
#define up(i, a, b)for(int i = a;i <= b; i++)
#define down(i, a, b)for(int i =a;i >= b; i--)
#define inf 0x3f3f3f3f
#define int long long
using namespace std;
const int maxn = 1e6 + 700;
char s[maxn], t[maxn];
char tmp[maxn<<1], len[maxn<<1], a[maxn<<1], b[maxn<<1];
int Len[maxn << 1];
int cha[maxn<<1], cnt[maxn<<1];
int init(char *st){
int i, len = strlen(st);
tmp[0] = '@';
for(i = 1;i <= 2*len; i+=2){
tmp[i] = '#';tmp[i+1] = st[i/2];
}
tmp[2*len+1] = '#';
tmp[2*len+2] = '$';
tmp[2*len+3] = 0;
return 2*len+1;
}
void manacher(char *st, int len){
int mx = 0, ans = 0, po = 0;
for(int i = 1;i <= len; i++){
if(mx > i)Len[i] = min(mx - i, Len[2*po-i]);
else Len[i] = 1;
while(st[i-Len[i]] == st[i+Len[i]])Len[i]++;
if(Len[i] + i > mx){mx = Len[i] + i; po = i;}
ans = max(ans, Len[i]);
}
int now = 0;
for(int i = 1;i <= len; i++){
cha[i]++;
cha[i+Len[i]]--;
}
for(int i = 1;i <= len; i++){
now += cha[i];
if(i%2 == 0)cnt[i/2] = now;
}
}
int nxt[maxn<<1], extend[maxn<<1];
int La, Lb;
void pre_ekmp(){
nxt[0] = Lb;
int j = 0;
while(j+1 < Lb && b[j] == b[j+1])j++;
nxt[1] = j;
int k = 1;
for(int i = 2;i < Lb; i++){
int p = nxt[k] +k - 1;
int L = nxt[i-k];
if(i+L < p+1)nxt[i] = L;
else{
j = max(0ll, p-i+1);
while(i + j < Lb && b[i+j] == b[j])j++;
nxt[i] = j;
k = i;
}
}
}
void ekmp(){
pre_ekmp();
int j = 0;
while(j < La && j < Lb && b[j] == a[j])j++;
extend[0] = j;
int k = 0;
for(int i = 1;i < La; i++){
int p = extend[k] + k - 1;
int L = nxt[i - k];
if(i + L < p + 1) extend[i] = L;
else{
j = max(0ll, p-i+1);
while(i + j < La && j < Lb && a[i+j] == b[j])j++;
extend[i] = j;
k = i;
}
}
}
signed main()
{
scanf("%s", a);
scanf("%s", b);
La = strlen(a);
reverse(a, a+La);
int len = init(a);
manacher(tmp, len);
Lb = strlen(b);
ekmp();
// up(i, 0, La-1)cout << extend[i] << endl;
int ans = 0;
// up(i, 1, La)cout << cnt[i] << endl;
up(i, 1, La){
ans += extend[i-1]*cnt[i-1];
}
printf("%lld\n", ans);
return 0;
}