首先把字符串用#间隔开
总数-连续的字符串
连续的用manacher求
总数如何求?
f[i]表示以i为中心有多少对对称的字符(不包含#,但包含本身)
ans=∑2^f[i]-n
原字符串的位置i对应新字符串的位置2*i
f[i]=∑str[j]*str[i-j]
第一次a=1,b=0
第二次a=0,b=1
这样算出来是对称的个数,(f[i]+1)/2是对数
总数-连续的字符串
连续的用manacher求
总数如何求?
f[i]表示以i为中心有多少对对称的字符(不包含#,但包含本身)
ans=∑2^f[i]-n
原字符串的位置i对应新字符串的位置2*i
f[i]=∑str[j]*str[i-j]
第一次a=1,b=0
第二次a=0,b=1
这样算出来是对称的个数,(f[i]+1)/2是对数
最后直接计算就可以了
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<iostream>
#define maxn 500010
#define pi acos(-1)
#define mod 1000000007
using namespace std;
struct yts
{
double r,i;
yts operator+(yts x) {yts ans;ans.r=r+x.r;ans.i=i+x.i;return ans;}
yts operator-(yts x) {yts ans;ans.r=r-x.r;ans.i=i-x.i;return ans;}
yts operator*(yts x) {yts ans;ans.r=r*x.r-i*x.i;ans.i=r*x.i+i*x.r;return ans;}
}a[maxn],b[maxn],temp[maxn];
int n,m,digit;
long long ans=0;
long long f[maxn];
char s[maxn],s1[maxn];
int p[maxn];
long long Pow[maxn];
void FFT(yts x[],int n,int type)
{
if (n==1) return;
for (int i=0;i<n;i+=2) temp[i>>1]=x[i],temp[n+i>>1]=x[i+1];
memcpy(x,temp,sizeof(yts)*n);
yts *l=x,*r=x+(n>>1);
FFT(l,n>>1,type);FFT(r,n>>1,type);
yts root,w;
root.r=cos(2*pi*type/n);root.i=sin(2*pi*type/n);
w.r=1;w.i=0;
for (int i=0;i<(n>>1);i++,w=w*root)
temp[i]=l[i]+w*r[i],temp[(n>>1)+i]=l[i]-w*r[i];
memcpy(x,temp,sizeof(yts)*n);
}
long long manacher()
{
for (int i=0;i<n;i++) s1[2*i+1]='#',s1[2*i+2]=s[i];
s1[0]='-';s1[2*n+1]='#';s1[2*n+2]='+';n<<=1;
int id=0,mx=0;
long long ans=0;
for (int i=1;i<=n;i++)
{
if (mx>i) p[i]=min(mx-i,p[2*id-i]);
else p[i]=1;
while (s1[i+p[i]]==s1[i-p[i]]) p[i]++;
if (i+p[i]>mx) id=i,mx=i+p[i];
ans=(ans+p[i]/2)%mod;
}
return ans;
}
int main()
{
scanf("%s",s);
n=strlen(s);
for (digit=1;digit<(n<<1);digit<<=1);
for (int i=0;i<n;i++) if (s[i]=='a') a[i].r=1;
FFT(a,digit,1);
for (int i=0;i<digit;i++) b[i]=b[i]+a[i]*a[i];
memset(a,0,sizeof(a));
for (int i=0;i<n;i++) if (s[i]=='b') a[i].r=1;
FFT(a,digit,1);
for (int i=0;i<digit;i++) b[i]=b[i]+a[i]*a[i];
FFT(b,digit,-1);
for (int i=0;i<digit;i++) f[i]=(long long)(b[i].r+0.5)/digit;
Pow[0]=1;
for (int i=1;i<=n;i++) Pow[i]=Pow[i-1]*2%mod;
for (int i=0;i<digit;i++) ans=(ans+Pow[(f[i]+1)>>1]-1)%mod;
ans=(ans-manacher()+mod)%mod;
printf("%lld\n",ans);
return 0;
}