题链:https://nanti.jisuanke.com/t/A2150
题意:给出两个串S,T。求S的子串接上T的前缀子串(S子串的长度要大于T前缀子串的长度。)所能组成的回文串的个数。
思路:我们可以把S的子串s分为两部分s1和s2。那么s2一定得是回文串且长度要大于1,s1与T的子串t又组成回文串。
马拉车求出S(假设为cbaba)的每个回文子串。
对于回文子串“aba”,可选的s2有
1.“b”(对应的s1可以为"a","ba","cba",那么我们需要的T的子串前缀为"a","ab","abc")。
2."aba"(对应的s1可以为"b","cb",那么我们需要的T的子串前缀为"b","bc")。
我们可以发现我们需要知道的是s[i],s[i-1],s[i-2],...,s[0]与T的最长公共前缀(LCP)的长度。那我们反转一下S用扩展KMP求一下extend,再把extend反转一下就是我们需要的。
现在对于S的回文子串,它对答案的贡献为:
1.当长度为奇数,我们假设回文子串的长度为len,在中间的字符的下标为pos,那么可选的s2为(S[pos,pos],S[pos-1,pos+1],S[pos-2,pos+2],...S[pos-len/2,pos+len/2]);
对应的对答案的贡献贡献为(extend[pos-1],extend[pos-2],extend[pos-3],...,extend[0])(用前缀和O(1)求即可。)
2.当长度为偶数,我们假设回文子串的长度为len,在中间靠左的字符(例如"abba"中第一个b的下标)的下标为pos,那么可选的s2为(S[pos,pos+1],S[pos-1,pos+2],S[pos-2,pos+3],...S[pos-len/2-1,pos+len/2]);
对应的对答案的贡献贡献为(extend[pos-1],extend[pos-2],extend[pos-3],...extend[0])(用前缀和O(1)求即可。)
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e6+10;
char str1[N],str2[N],t[N];
int p[N];
void Manacher()
{
int k=0;
str2[k++]='!';
int len=strlen(str1);
for(int i=0;i<len;i++)
str2[k++]='#',str2[k++]=str1[i];
str2[k++]='#';
str2[k]=0;
int sum=0;
p[0]=1;
int id=0,mx=0;
int mid=0;
for(int i=1;i<k;i++)
{
if(i<mx) p[i]=min(mx-i,p[2*id-i]);
else p[i]=1;
while(str2[i-p[i]]==str2[i+p[i]]) p[i]++;
if(i+p[i]>mx)
{
mx=i+p[i];
id=i;
}
}
}
int extend[N],nex[N];
void GetNext(char* T, int& m, int nex[])
{
int a=0,p=0;
nex[0]=m;
for (int i=1;i<m;i++)
{
if (i>=p||i+nex[i-a]>=p)
{
if (i>=p) p=i;
while (p<m&&T[p]==T[p - i]) p++;
nex[i]=p-i;
a=i;
}
else
nex[i]=nex[i-a];
}
}
void GetExtend(char* S, int n, char* T, int& m, int extend[], int nex[])
{
int a=0,p=0;
GetNext(T, m, nex);
for (int i=0;i<n;i++)
{
if (i>=p||i+nex[i-a]>=p)
{
if (i>=p) p=i;
while (p<n&&p-i<m&&S[p]==T[p-i]) p++;
extend[i]=p-i;
a=i;
}
else
extend[i]=nex[i-a];
}
}
ll pre[N];
ll ask(int l,int r){
if(l>r) return 0;
if(l<=0) return pre[r];
return pre[r]-pre[l-1];
}
int main(){
scanf("%s%s",str1,t);
int lens=strlen(str1),lent=strlen(t);
Manacher();
reverse(str1,str1+lens);
GetExtend(str1,lens,t,lent,extend,nex);
reverse(extend,extend+lens);
pre[0]=extend[0];
for(int i=1;i<lens;i++)
pre[i]=pre[i-1]+extend[i];
ll ans=0;
for(int i=2;i<lens*2+3;i++){
int len=p[i]-1;
if(len<=0) continue;
if(len&1){//子串长度为奇数
//中间字符的下标
int pos=(i-p[i])/2+len/2;
//前缀和的右边界
int r=pos-1;
//前缀和的左边界
int l=(i-p[i])/2-1;
ans+=ask(l,r);
}else{//子串长度为奇数
//中间靠左字符的下标
int pos=(i-p[i])/2+len/2-1;
int r=pos-1;
int l=(i-p[i])/2-1;
ans+=ask(l,r);
}
}
printf("%lld\n",ans);
return 0;
}