题目描述
传送门
题目大意: 从只含a,b的字符串中选出一个子序列,满足
1.位置和字符都关于某条对称轴对称
2.不能是连续的一段
求合法的子序列个数。
题解
不能是连续的一段,对于连续一段的答案可以用manacher求解.
关于位置和字符对称的问题,我们可以对于a,b分开考虑,要计算以一个位置为对称轴的对数。
设当前计算的字符为a,那么把所有是a的位置赋值成1,即
f[i]=1
如果两个位置关于x对称,那么一定满足
f[x−k]=f[x+k]
,下标是定值,可以用FFT求解啊。
那么对于每个位置,设对称的a,b个数和为cnt,那么能产生的合法的序列个数就是
2cnt−1
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 300030
#define LL long long
#define p 1000000007
#define pi acos(-1)
using namespace std;
int n,n1,len,ch[N],cnt[N],t[N],L,R[N];
char s[N]; int mi[N];
struct data{
double x,y;
data(double X=0,double Y=0) {
x=X,y=Y;
}
}f[N],g[N];
data operator +(data a,data b){
return data(a.x+b.x,a.y+b.y);
}
data operator -(data a,data b){
return data(a.x-b.x,a.y-b.y);
}
data operator *(data a,data b)
{
return data(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
void manacher(int n)
{
int mx=0,id=0;
for (int i=1;i<=n;i++){
if (mx>=i) t[i]=min(t[2*id-i],mx-i);
else t[i]=0;
for (;i-t[i]-1>0&&i+t[i]+1<=n&&ch[i-t[i]-1]==ch[i+t[i]+1];t[i]++);
if (t[i]+i>=mx) mx=t[i]+i,id=i;
}
}
void FFT(data a[N],int n,int opt)
{
for (int i=0;i<n;i++)
if (i>R[i]) swap(a[i],a[R[i]]);
for (int i=1;i<n;i<<=1) {
data wn=data(cos(pi/i),opt*sin(pi/i));
for (int p1=i<<1,j=0;j<n;j+=p1){
data w=data(1,0);
for (int k=0;k<i;k++,w=w*wn){
data x=a[j+k]; data y=w*a[j+k+i];
a[j+k]=x+y; a[j+k+i]=x-y;
}
}
}
if (opt==-1)
for (int i=0;i<n;i++) a[i].x/=n;
}
int main()
{
gets(s+1); n=strlen(s+1); int ans=n;
for (int i=1;i<=n;i++) ch[2*i-1]=0,ch[2*i]=s[i]-'a'+1;
ch[2*n+1]=0;
manacher(2*n+1);
for (int i=1;i<=2*n+1;i++) ans+=t[i]/2,ans%=p;
len=2*(n-1); mi[0]=1;
for (int i=1;i<=n+10;i++) mi[i]=mi[i-1]*2%p;
for (n1=1;n1<=len;n1<<=1) L++;
for (int i=0;i<=n1;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
for (int i=0;i<n;i++) f[i].x=(s[i+1]=='a');
FFT(f,n1,1);
for (int i=0;i<n1;i++) f[i]=f[i]*f[i];
FFT(f,n1,-1);
for (int i=0;i<n;i++) g[i].x=(s[i+1]=='b');
FFT(g,n1,1);
for (int i=0;i<n1;i++) g[i]=g[i]*g[i];
FFT(g,n1,-1);
for (int i=0;i<n1;i++) cnt[i]=(int)(f[i].x+g[i].x+0.1);
int ans1=0;
for (int i=0;i<n1;i++) ans1+=mi[(cnt[i]+1)/2]-1,ans1%=p;
ans1=(ans1-ans+p)%p;
printf("%d\n",ans1);
}