题意
给定一个由a和b构成的串,求不连续回文子串的个数
输入
输入一个只由a和b组成的字符串
输出
输出不连续回文子串的个数
数据范围与规定
字符串长度<=100000
题解
首先先考虑回文串,很容易想到manacher,但是这里要求的是不连续的子串,How can you do it?????
设p[i]为 以i为对称轴,相等的两个字符共有多少个
例如ababa
那么p[3]=2
所以说以i为中心 回文串总个数有(2^p[i])-1这么多个
for一遍加起来再减掉manacher算出的连续回文串就好了
剩下的就是如何快速求出p数组了
按照manacher的思路,先加上分隔符
@a@b@a@b@a@(因为csdn里不能用#所以我就@了。。)
观察两个b,发现他们在原数组中位置分别为2和4
2+4=6
有点神奇哦。。好像加起来后这个下标是加了分隔符后第三个a的位置
那么可以大胆考虑,p[i]=(Σ[1<=j<=i-1]bool(str[j]==str[i-j]))+1>>1
里面是个卷积。可以fft优化
观察n的范围,10^5的数据,哇真的要用fft优化耶
这么神奇的推fft。。我%%%%
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long LL;
const LL mod=1000000007;
const double PI=acos(-1.0);
const int MAXN=263000;
struct Complex
{
double r,i;//real imag
Complex() {}
Complex(double _r,double _i){r=_r;i=_i;}
friend Complex operator + (const Complex &x,const Complex &y){return Complex(x.r+y.r,x.i+y.i);}
friend Complex operator - (const Complex &x,const Complex &y){return Complex(x.r-y.r,x.i-y.i);}
friend Complex operator * (const Complex &x,const Complex &y){return Complex(x.r*y.r-x.i*y.i,x.r*y.i+x.i*y.r);}
}a[MAXN],b[MAXN];
int R[MAXN],L,m,p[MAXN];
LL ans,tmp[MAXN],list[MAXN];
char s[MAXN],now[MAXN];
void fft(Complex *y,int len,int on)
{
for(int i=0;i<len;i++)if(i<R[i])swap(y[i],y[R[i]]);
for(int i=1;i<len;i<<=1)//枚举需要合并的长度 合并后的长度就成了i*2对吧。所以无需枚举至len
{
Complex wn(cos(PI/i),sin(on*PI/i));//无需乘2,因为合并后长度i*2,用到的单位复数根只有i
for(int j=0;j<len;j+=(i<<1))//被分成了L/(i<<1)段序列
{
Complex w(1,0);//注意一点,w是在for循环执行完毕后才累乘,因为我们还有w^0对吧
for(int k=0;k<i;k++,w=w*wn)//枚举前半部分,后半部分加上一个i就可以了嘛
{
Complex u=y[j+k];//j+k即是前半部分
Complex v=w*y[j+k+i];//j+k+i即是后半部分
y[j+k]=u+v;
y[j+k+i]=u-v;
}
}
}
if(on==-1)for(int i=0;i<len;i++)y[i].r/=len;//IFFT 每个数都要/len
}
int manacher()
{
LL ss=0;
for(int i=1;i<=m;i++)now[2*i-1]='#',now[2*i]=s[i-1];
m=2*m+1;
now[m]='#';
int k=0,r=0;
for(int i=1;i<=m;i++)
{
int j=k-(i-k);//i以k为中心的对称点
if(i<=r)
{
if(p[j]<r-i+1)p[i]=p[j];
else p[i]=r-i+1;
}
else p[i]=1;
while(i-p[i]>=1 && i+p[i]<=m && now[i-p[i]]==now[i+p[i]])p[i]++;
if(i+p[i]-1>r){k=i;r=i+p[i]-1;}
ss=(ss+p[i]/2)%mod;
}
return ss;
}
void gettmp()
{
tmp[0]=1;
for(int i=1;i<MAXN;i++)tmp[i]=(tmp[i-1]*2)%mod;
}
int main()
{
scanf("%s",s);
m=strlen(s);
L=0;int len;
for(len=1;len<=m*2;len<<=1)L++;
for(int i=0;i<len;i++)R[i]=(R[i>>1]>>1)|(i&1)<<(L-1);
for(int i=0;i<m;i++)if(s[i]=='a')a[i]=Complex(1,0);
fft(a,len,1);
for(int i=0;i<len;i++)b[i]=b[i]+(a[i]*a[i]);
for(int i=0;i<len;i++)
{
if(s[i]=='b')a[i]=Complex(1,0);
else a[i]=Complex(0,0);
}
fft(a,len,1);
for(int i=0;i<len;i++)b[i]=b[i]+(a[i]*a[i]);
fft(b,len,-1);
LL del=manacher();
for(int i=0;i<len;i++)list[i]=(LL)(b[i].r+0.5);
ans=0;gettmp();
for(int i=1;i<len;i++)ans=(ans+tmp[list[i]+1>>1]-1)%mod;
printf("%lld\n",(ans-del+mod+1)%mod);
return 0;
}