分析:
直接计算不连续回文子序列有点困难
我们可以先计算出回文子序列的数量,答案就是“回文子序列的数量-连续回文子序列的数量”
而连续回文子序列实际上就是回文子串,用manacher算法O(n)计算即可
而回文子序列的数量就有一点难度了
设
f[i]
f
[
i
]
表示以
i
i
为中心的对称字符对个数(字符串是manacher预处理后的字符串)
那么对于每个中心,我们有
2f[i]−1
2
f
[
i
]
−
1
种方案,答案即为
∑2n−1i=02f[i]−1
∑
i
=
0
2
n
−
1
2
f
[
i
]
−
1
现在我们的问题就是如何计算f数组
我们发现对
f[i]
f
[
i
]
有贡献的一对字符在原字符串数组中的位置之和一定是
i
i
1 2 3 4 5 6 7
s = b c b
s'= # b # c # b #
以为中心的对称字符对 s′[2]=s′[6]=b s ′ [ 2 ] = s ′ [ 6 ] = b ,在 s s 中的位置和为4
那么显然:
f[i]=∑i−1j=1[aj=bi−j]+12 f [ i ] = ∑ j = 1 i − 1 [ a j = b i − j ] + 1 2
括号里就是一个类似卷积的形式,我们可以用FFT来计算
首先考虑 a a 对答案的贡献,把的价值设为1,把 b b 的价值设为0,那么上式可化为:
再考虑 b b 对答案的贡献,把的价值设为1,把 a a 的价值设为0,计算上式:
两者相加再加1最后除以2取下整就是 f f 了
tip
计算f的时候要开ll
注意数组的大小是
fn
f
n
数组的大小注意啦
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=300010;
const ll p=1e9+7;
const double pi=acos(-1.0);
struct node{
double x,y;
node (double xx=0,double yy=0) {
x=xx;y=yy;
}
};
node a[N],b[N],o[N],_o[N];
node operator +(const node &a,const node &b) {return node(a.x+b.x,a.y+b.y);}
node operator -(const node &a,const node &b) {return node(a.x-b.x,a.y-b.y);}
node operator *(const node &a,const node &b) {return node(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
int fn,n,RL[N];
char s[N],ss[N];
ll f[N],er[N];
void init(int n) {
for (int i=0;i<n;i++) {
o[i]=node(cos(2.0*i*pi/n),sin(2.0*i*pi/n));
_o[i]=node(cos(2.0*i*pi/n),-sin(2.0*i*pi/n));
}
}
void FFT(int n,node *a,node *w) {
int i,j=0,k;
for (i=0;i<n;i++) {
if (i>j) swap(a[i],a[j]);
for (int l=n>>1;(j^=l)<l;l>>=1);
}
for (i=2;i<=n;i<<=1) {
int m=i>>1;
for (j=0;j<n;j+=i)
for (k=0;k<m;k++) {
node z=a[j+m+k]*w[n/i*k];
a[j+m+k]=a[j+k]-z;
a[j+k]=a[j+k]+z;
}
}
}
int prepare() {
ss[0]='@';
for (int i=1;i<=2*n;i+=2) {
ss[i]='#';
ss[i+1]=s[(i+1)/2];
}
ss[n*2+1]='#';
ss[n*2+2]='$';
return n*2+1;
}
ll manacher() {
int len=prepare();
int mx=1,pos=1;
ll ans=0;
for (int i=1;i<=len;i++) {
int j=2*pos-i;
if (i<mx) RL[i]=min(mx-i,RL[j]);
else RL[i]=1;
while (ss[i+RL[i]]==ss[i-RL[i]]) RL[i]++;
if (i+RL[i]>mx) mx=i+RL[i],pos=i;
ans+=(ll)(RL[i]/2); ans%=p;
}
return ans;
}
int main()
{
scanf("%s",s+1);
n=strlen(s+1);
fn=1;
while (fn<=n+n) fn<<=1;
init(fn);
er[0]=1;
for (int i=1;i<fn;i++) er[i]=(er[i-1]*2)%p;
ll ans=0;
for (int i=1;i<=n;i++)
if (s[i]=='a') a[i].x=1;
FFT(fn,a,o);
for (int i=0;i<fn;i++) b[i]=a[i]*a[i];
memset(a,0,sizeof(a));
for (int i=1;i<=n;i++)
if (s[i]=='b') a[i].x=1;
FFT(fn,a,o);
for (int i=0;i<fn;i++) b[i]=b[i]+a[i]*a[i];
FFT(fn,b,_o);
for (int i=0;i<fn;i++)
f[i]=(ll)(b[i].x/fn+0.5);
for (int i=0;i<fn;i++)
ans+=er[f[i]+1>>1]-1,ans%=p;
printf("%lld",(ans-manacher()+p)%p);
return 0;
}