bzoj3160 万径人踪灭
原题地址:http://www.lydsy.com/JudgeOnline/problem.php?id=3622
题意:
给定一个长度为N且只有a,b的字符串,问有多少种方案从中选取一个子序列,使得:
1.位置和字符都关于某条对称轴对称。
2.不能是连续的一段。
数据范围
n<=100000
题解:
首先求不连续的,就先求所有的,再用manacher求连续的剪掉即可。
关键大概就是想到以每个位置为中心,两边对称相同的有多少对
fi
f
i
,这个中心的答案就是
2fi−1
2
f
i
−
1
。
fi=∑ij=0[si−j==si+j]
f
i
=
∑
j
=
0
i
[
s
i
−
j
==
s
i
+
j
]
当然有夹缝处,于是我们把i数组下标扩大一倍:
fi=∑[j+k==i] [sj==sk]
f
i
=
∑
[
j
+
k
==
i
]
[
s
j
==
s
k
]
,这是一个卷积的形式。
介于0x0=0,1x0=0,1x1=1,
我们想到了规定某一种字符系数为1,然后fft做多项式乘法,这样只有相同的才会贡献。
对a做一遍FFT,对b做一遍FFT,加一起再IDFT,得到 f f <script type="math/tex" id="MathJax-Element-10">f</script>数组,减去manacher的答案即可。
代码:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#define LL long long
using namespace std;
const int mod=1000000007;
const int MXN=530000;
const double Pi=acos(-1);
struct Virt
{
double r,i;
Virt(){}
Virt(double r,double i):r(r),i(i){}
Virt operator+(const Virt &A){return Virt(r+A.r,i+A.i);}
Virt operator-(const Virt &A){return Virt(r-A.r,i-A.i);}
Virt operator*(const Virt &A){return Virt(r*A.r-i*A.i,r*A.i+i*A.r);}
}omg[MXN],_omg[MXN],a[MXN],b[MXN],c[MXN];
int len,n,p,R[MXN],po[MXN],v[MXN];
char s[MXN];
void FFT(Virt *x,int opt)
{
Virt *w; if(opt==1) w=omg; else w=_omg;
for(int i=0;i<n;i++) if(i<R[i]) swap(x[i],x[R[i]]);
for(int m=2;m<=n;m=m<<1)
{
int l=m>>1;
for(int j=0;j<n;j+=m)
{
for(int i=0;i<l;i++)
{
Virt y=x[j+i+l]*w[n/m*i];
x[j+i+l]=x[j+i]-y;
x[j+i]=x[j+i]+y;
}
}
}
if(opt==-1) for(int i=0;i<n;i++) x[i].r=x[i].r/(double)n;
}
int manacher()
{
for(int i=len;i>=0;i--) s[2*i+1]='#',s[2*i+2]=s[i]; s[0]='+'; s[2*len+2]='-';
int mx=0,id=0; int ret=0;
for(int i=1;i<=2*len;i++)
{
if(id+mx>i) v[i]=min(v[2*id-i],id+mx-i);
else v[i]=1;
while(s[i+v[i]]==s[i-v[i]]) v[i]++;
if(i+v[i]>id+mx) id=i,mx=v[i];
ret=(ret+v[i]/2)%mod;
}
return ret;
}
int main()
{
scanf("%s",s); len=strlen(s);
int m; for(p=0,n=1,m=2*len;n<m;n=n<<1,p++);
po[0]=1; for(int i=1;i<=n;i++) po[i]=(po[i-1]*2)%mod;
R[0]=0; for(int i=1;i<n;i++) {R[i]=(R[i>>1]>>1)|((i&1)<<(p-1));}
for(int i=0;i<n;i++){omg[i]=Virt(cos(2.0*Pi/(double)n*i),sin(2.0*Pi/(double)n*i)); _omg[i]=Virt(omg[i].r,-omg[i].i);}
for(int i=0;i<n;i++) {a[i]=Virt((s[i]=='a'),0); b[i]=Virt((s[i]=='b'),0);}
FFT(a,1); FFT(b,1);
for(int i=0;i<n;i++) c[i]=a[i]*a[i]+b[i]*b[i];
FFT(c,-1);
int ret=0; for(int i=0;i<n;i++) {ret=(ret+po[((int)(c[i].r+0.5)+1)/2])%mod; ret--; if(ret<0) ret+=mod;}
ret=(ret-manacher()+mod)%mod;
printf("%d\n",ret);
return 0;
}