题目大意:
给你一个串,分成偶数段,第
i
i
块的字符串记为。询问有多少种方案,假设分成
k
k
块,满足。
分析:
直接做是比较难的,所以我们考虑把这个串变一下。
假设原串
s=s[0]s[1]s[2]……s[n−1]
s
=
s
[
0
]
s
[
1
]
s
[
2
]
…
…
s
[
n
−
1
]
,
那么
s′=s[0]s[n−1]s[1]s[n−2]s[2]s[n−3]……
s
′
=
s
[
0
]
s
[
n
−
1
]
s
[
1
]
s
[
n
−
2
]
s
[
2
]
s
[
n
−
3
]
…
…
我们发现,原串中分成的一个块,对应改变串的一个回文串。
想想对应的两个块,左边块的最左边字符第一个插入,而右边串的最左边的字符最后插入,这个与这两个块长度无关。
所以相当于把改变串分成若干个偶回文串,求方案数。
考虑直接建回文树,每次从当前位置的某个
fail
f
a
i
l
转移。
显然如果串全是一个字符,跳
fail
f
a
i
l
就是
O(n)
O
(
n
)
的。
有一个性质,每所有fail的长度存在
logn
l
o
g
n
个等差数列。
证明:
对于当前的串
t
t
,考虑他的一半与另一半
t′′
t
″
,如果是奇数回文串,回文中心那个字符加到
t′
t
′
的后面与
t′′
t
″
的前面。
因为
lent′=lent′′
l
e
n
t
′
=
l
e
n
t
″
,有
t′[i]=t′′[lent′−i]
t
′
[
i
]
=
t
″
[
l
e
n
t
′
−
i
]
。
而对于一个长度大于
lent′′
l
e
n
t
″
的
fail
f
a
i
l
,
t′′
t
″
一定是他的一个后缀,也就是
t′
t
′
一定为他的前缀,可以理解为
t′
t
′
串移动了
delta
d
e
l
t
a
位,
delta=lent−lent.fail
d
e
l
t
a
=
l
e
n
t
−
l
e
n
t
.
f
a
i
l
。
假设三个满足上述条件的后缀之间移动了
delta1
d
e
l
t
a
1
和
delta2
d
e
l
t
a
2
位,
有
t[i]=t[i+delta1]
t
[
i
]
=
t
[
i
+
d
e
l
t
a
1
]
且
t[i]=t[i+delta2]
t
[
i
]
=
t
[
i
+
d
e
l
t
a
2
]
,即
t[i+delta1]=t[i+delta2]
t
[
i
+
d
e
l
t
a
1
]
=
t
[
i
+
d
e
l
t
a
2
]
。
那么
t[i]=t[i+delta2−delta1]
t
[
i
]
=
t
[
i
+
d
e
l
t
a
2
−
d
e
l
t
a
1
]
。
所以,每移动
gcd(delta1,delta2)
g
c
d
(
d
e
l
t
a
1
,
d
e
l
t
a
2
)
位,就有一个回文后缀,前提是这个串超过原串的一半。
所以一定有
logn
l
o
g
n
个等差数列。
我们维护一个
delta
d
e
l
t
a
和一个
last
l
a
s
t
,后者为一个等差数列的起点。维护以它作为等差数列末项时每一项的f值之和g[p]。
怎么维护呢?设当前节点为末项的等差数列有b1,b2,b3,其中b1>b2>b3。那么有g[p]=f[i-b1]+f[i-b2]+f[i-b3]。
根据回文串的性质,不难发现S[i-b2,i-d]=S[i-b3,i],S[i-b1,i-d]=S[i-b2,i],那么在g[fail[p]]中就已经包含了f[i-b1]和f[i-b2],只要把f[i-b3]加上就好了。
代码:
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#define LL long long
const int maxn=1e6+7;
const LL mod=1e9+7;
using namespace std;
char str[maxn],s[maxn];
int n,cnt;
LL f[maxn],g[maxn];
struct node{
int fail,len,last,delta;
int son[26];
}t[maxn];
void build()
{
cnt=1;
t[0].fail=1;
t[0].len=0;
t[1].fail=0;
t[1].len=-1;
int now=1;
for (int i=1;i<=n;i++)
{
while (s[i]!=s[i-t[now].len-1]) now=t[now].fail;
if (!t[now].son[s[i]-'a'])
{
cnt++;
int k=t[now].fail;
while (s[i]!=s[i-t[k].len-1]) k=t[k].fail;
t[cnt].fail=t[k].son[s[i]-'a'];
t[now].son[s[i]-'a']=cnt;
t[cnt].len=t[now].len+2;
t[cnt].delta=t[cnt].len-t[t[cnt].fail].len;
if (t[cnt].delta==t[t[cnt].fail].delta) t[cnt].last=t[t[cnt].fail].last;
else t[cnt].last=cnt;
}
now=t[now].son[s[i]-'a'];
for (int j=now;j>0;j=t[t[j].last].fail)
{
g[j]=f[i-t[t[j].last].len];
if (t[j].last!=j) g[j]=(g[j]+g[t[j].fail])%mod;
if (i%2==0) f[i]=(f[i]+g[j])%mod;
}
}
}
int main()
{
scanf("%s",str);
n=strlen(str);
int num=0;
for (int i=1;i<=n;i+=2) s[i]=str[num++];
num=n-1;
for (int i=2;i<=n;i+=2) s[i]=str[num--];
f[0]=1;
build();
printf("%lld",f[n]);
}