本文版权归ljh2000和博客园共有,欢迎转载,但须保留此声明,并给出原文链接,谢谢合作。
本文作者:ljh2000
作者博客:http://www.cnblogs.com/ljh2000-jump/
转载请注明出处,侵权必究,保留最终解释权!
题目链接:BZOJ3160
正解:FFT+manacher
解题报告:
参考博客:戳这里
题目求的是一个字符串的不连续回文子序列个数。
考虑用所有的回文子序列个数$-$连续回文子序列就是答案。
求连续回文子序列的个数只需要跑一遍$manacher$,然后得到以每个点为对称中心的$p$数组之后,可以直接统计出答案。
回文子序列的个数似乎不好考虑,我们不妨考虑以每个地方(包括间隔)为对称点的回文子序列个数。
我们如果知道了两边对应位置相等的个数有$x$个,根据二项式定理$C(n,1)+C(n,2)+C(n,3)+…+C(n,n)=2^n-1$,所以答案就是$2^x-1$。
而$a$、$b$是彼此独立的,所以我们可以分别考虑$a$和$b$。
我们设出一个多项式,若这一位是$a$那么系数就是$1$,容易发现把这个多项式平方之后,$i$项对应的系数就是以$i$为对称中心的相等的$a$的个数。
因为我一直写的是递归版的$FFT$,然后被卡常了...
拖了一个非递归版的$FFT$就愉快地$AC$了。
//It is made by ljh2000
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <ctime>
#include <vector>
#include <queue>
#include <map>
#include <set>
#include <string>
#include <complex>
using namespace std;
typedef long long LL;
typedef complex<double> C;
const int MOD = 1000000007;
const double pi = acos(-1);
const int MAXN = 300011;
int n,L,f[MAXN],mx,pos,m,p[MAXN];
char ch[MAXN],s[MAXN];
C a[MAXN],b[MAXN],aa[MAXN],bb[MAXN];
int ans[MAXN],tot,out,er[MAXN],R[MAXN];
//ans[i]表示以i为对称中心的两边的对称字符数量(包含i)
inline int getint(){
int w=0,q=0; char c=getchar(); while((c<'0'||c>'9') && c!='-') c=getchar();
if(c=='-') q=1,c=getchar(); while (c>='0'&&c<='9') w=w*10+c-'0',c=getchar(); return q?-w:w;
}
inline LL fast_pow(LL x,LL y){
LL r=1;
while(y>0) {
if(y&1) r*=x,r%=MOD;
x*=x; x%=MOD;
y>>=1;
}
return r;
}
inline void fft(C *a,int n,int f){
for(int i=0;i<n;i++) if(i<R[i]) swap(a[i],a[R[i]]);//交换位置
for(int i=1;i<n;i<<=1){//待合并区间长度
C wn(cos(pi/i),sin(f*pi/i)),x,y;//这里就不用再*2了,因为合并后的区间长度是i的两倍
for(int j=0;j<n;j+=i<<1){//起始位置
C w(1,0);
for(int k=0;k<i;k++,w*=wn){//第k个
x=a[j+k];y=w*a[j+i+k];
a[j+k]=x+y;
a[j+i+k]=x-y;
}
}
}
}
inline LL manacher(){
pos=0; mx=0; s[0]='%'; s[1]='#'; m=1;
for(int i=0;i<n;i++) s[++m]=ch[i],s[++m]='#';
for(int i=1;i<=m;i++) {
if(i<mx) p[i]=min(p[2*pos-i],mx-i); else p[i]=1;
for(;i+p[i]<=m/*!!!*/ && s[i+p[i]]==s[i-p[i]];p[i]++);
if(i+p[i]>mx) { mx=i+p[i]; pos=i; }
tot+=p[i]/2;
tot%=MOD;//一个回文串的贡献
}
return tot;
}
inline void work(){
scanf("%s",ch); n=strlen(ch); int N=n<<1,ll=0;
for(int i=0;i<=N;i++) er[i]=fast_pow(2,i);
for(int i=0;i<n;i++) if(ch[i]=='a') a[i]=b[i]=1;
for(L=1;L<=N;L<<=1) ll++; for(int i=0;i<L;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(ll-1));
fft(a,L,1); fft(b,L,1); for(int i=0;i<L;i++) a[i]*=b[i];
fft(a,L,-1); for(int i=0;i<N;i++) ans[i]=(int)(a[i].real()/L+0.5);
for(int i=0;i<n;i++) if(ch[i]=='b') aa[i]=bb[i]=1;
fft(aa,L,1); fft(bb,L,1); for(int i=0;i<L;i++) aa[i]*=bb[i]; fft(aa,L,-1);
for(int i=0;i<N;i++) ans[i]+=(int)(aa[i].real()/L+0.5);
for(int i=0;i<N;i++) ans[i]=er[(ans[i]+1)/2]-1;
for(int i=0;i<N;i++) out+=ans[i],out%=MOD;
out-=manacher(); out+=MOD; out%=MOD;
printf("%d",out);
}
int main()
{
work();
return 0;
}