我能不能扔发链接跑啊qaq,这题太神仙了我好像又不会做了….
问把一个字符串S分成 a1,a2....an(2|n) 这n段,满足 a1=an,a2=an−1..... 的方案数
先模型转化,我们构造一个新的串
S0Sn−1S1Sn−2S3.....
,然后你会惊讶的发现,问题等价于问把新串分成若干个偶数长度的回文串的方案数(易证)
那么这个新问题要怎么做呢
首先有两个结论
1.一个回文串S的后缀T如果是回文串等价于T是S的border
2.将一个串S的所有border按长度从小到大排序后,能形成log个等差数列
1还是挺好证的,2挺复杂的我也不会证,就当他是个很优美自然的结论好了qaq
有了这两个结论后,这个问题就可以用回文树做了
我们从左往右枚举S的所有前缀1~i
用f[i]表示将1~i分成若干个偶数回文串的方案数
我们对回文树上每个节点p,维护g[p]表示p这个回文串在1~i的前缀中最后一次出现作为某个等差数列(
F=a1,a2,a3...an(ai<ai+1)
)的最后一项
an
时,
f[i−a1]+f[i−a2]+f[i−a3]....
的和
next[p]表示前一个等差数列最后一项是哪个回文串q
注意到
F=a1,a2,a3
这个等差数列,因为p=a1会fail到p=a2,又因为回文串的border的性质,有
S[i−a1,i−d]=S[i−a2,i],S[i−a2,i−d]=S[i−a3,i]
,因此
f[i−a1],f[i−a2]
已经在i-d被fail[p]统计过了(可以证明fail[p]一定是a2且fail[p]一定在i-d作为这个等差数列的尾项,证明过程此处略去),那么g[p]只要在g[fail[p]]的基础上再加上f[i-a3]这个值就行了
枚举到前缀i的时候一边跳i的log段等差数列一边维护每个尾项的g[p],复杂度 O(|S|log|S|)
code:
#include<set>
#include<map>
#include<deque>
#include<queue>
#include<stack>
#include<cmath>
#include<ctime>
#include<bitset>
#include<string>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<climits>
#include<complex>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn = 1100000;
const int mod = 1e9+7;
int n,m;
char S[maxn]; int str[maxn];
int id[maxn],f[maxn],g[maxn];
struct pTree
{
int fail[maxn],trans[maxn][26],diff[maxn],nxt[maxn],len[maxn];
int cnt,last;
void init()
{
cnt=1; last=0;
memset(trans[0],0,sizeof trans[0]);
memset(trans[1],0,sizeof trans[1]);
len[0]=0,len[1]=-1;
fail[0]=1;
}
int newnode(int l)
{
++cnt;
memset(trans[cnt],0,sizeof trans[cnt]);
len[cnt]=l;
return cnt;
}
int extend(int i)
{
int p=last,np,w=str[i];
while(str[i-len[p]-1]!=str[i]) p=fail[p];
if(!trans[p][w])
{
np=newnode(len[p]+2);
int t=fail[p];
while(str[i-len[t]-1]!=str[i]) t=fail[t];
fail[np]=trans[t][w];
diff[np]=len[np]-len[fail[np]];
if(diff[np]==diff[fail[np]]) nxt[np]=nxt[fail[np]];
else nxt[np]=fail[np];
trans[p][w]=np;
}
else np=trans[p][w];
return last=np;
}
}tr;
int main()
{
//freopen("tmp.in","r",stdin);
//freopen("tmp.out","w",stdout);
scanf("%s",S); n=strlen(S);
if(n&1) return puts("0"),0;
str[m=0]=-1;
for(int i=0;i<n/2;i++) str[++m]=S[i]-'a',str[++m]=S[n-i-1]-'a';
tr.init();
for(int i=1;i<=m;i++) id[i]=tr.extend(i);
f[0]=1;
for(int i=1;i<=m;i++)
{
for(int p=id[i];p;p=tr.nxt[p])
{
g[p]=f[i-tr.len[tr.nxt[p]]-tr.diff[p]];
if(tr.diff[p]==tr.diff[tr.fail[p]]) (g[p]+=g[tr.fail[p]])%=mod;
if(!(i&1)) (f[i]+=g[p])%=mod;
}
}
printf("%d\n",f[n]);
return 0;
}