题目大意
给你一个长为n的字符串S,现在你要把他划分成k段,记为p1p2…pk,其中对于任意1<=i<=k,满足
pi=pk−i−1
p
i
=
p
k
−
i
−
1
,且k为偶数。问划分方案数。
n<=1e6
解题思路
又是一道喜闻乐见的真·字符串题。可怜老年选手回文树都打不对了。
直接做可以设一个dp,f[i]表示做完前i位和后i位,划分的方案,然后每次枚举新的位置判断合法后转移即可。
但这样十分麻烦,感觉没有什么优化空间。
考虑进行字符串变换。
设新字符串
a=s1sns2sn−1...sn/2sn/2+1
a
=
s
1
s
n
s
2
s
n
−
1
.
.
.
s
n
/
2
s
n
/
2
+
1
,那么问题变成了回文串划分方案。
同样可以暴力dp。
现在考虑如何优化,题解告诉我们用“eertree”,有人甚至写过一篇论文,但没什么适用性。
放到回文树上考虑,假设前缀a[1..i]的最长回文后缀为lst,那么你从lst一直fail到根,把这些回文串转移一下就行。这样时间复杂度就变成了O(回文串个数)。
我们知道一个回文串s的最长回文后缀suf就是最长回文前缀pre,那么suf就是一个border,考虑到border能够划分为O(logn)段等差数列,考虑在回文树上维护等差数列的位置的f的和,到时候就可以log时间转移了。考虑设nxt[i]表示回文树点i沿fail到下一个等差数列的位置,我们维护i到nxt[i](不含nxt[i])的f和。
我们考虑首项最长一段等差数列,即lst所在。假设有
b1,b2,b3(b1>b2>b3)
b
1
,
b
2
,
b
3
(
b
1
>
b
2
>
b
3
)
,设公差为d,那么对应的f就是
f[i−b1],f[i−b2],f[i−b3]
f
[
i
−
b
1
]
,
f
[
i
−
b
2
]
,
f
[
i
−
b
3
]
,考虑到在i-d的位置,我们已经有了
f[i−d−b2],f[i−d−b3]
f
[
i
−
d
−
b
2
]
,
f
[
i
−
d
−
b
3
]
的和(实际上),存在了
b2
b
2
对应的点上,那么我们只要把
f[i−b3]
f
[
i
−
b
3
]
的值加进来即可。这就做完了
代码
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<map>
using namespace std;
typedef long long ll;
typedef double db;
#define fo(i,j,k) for(i=j;i<=k;i++)
#define fd(i,j,k) for(i=j;i>=k;i--)
#define cmax(a,b) (a=(a>b)?a:b)
#define cmin(a,b) (a=(a<b)?a:b)
const int N=1e6+5,mo=1e9+7,rt=3;
char s[N];
int a[N],fail[N],nxt[N],len[N],tr[N][26],diff[N],tt,f[N],g[N];
int n,i,lst[N],p;
int get(int x,int y)
{
while (a[i-len[x]-1]!=y)
x=fail[x];
return x;
}
int ins(int x,int y)
{
x=get(x,y);
if (!tr[x][y])
{
len[++tt]=len[x]+2;
int z=get(fail[x],y);
z=tr[z][y];
if (!z) z=2;
fail[tt]=z;
diff[tt]=len[tt]-len[fail[tt]];
if (diff[tt]==diff[fail[tt]]) nxt[tt]=nxt[fail[tt]];
else nxt[tt]=fail[tt];
if (nxt[tt]<=2) nxt[tt]=0;
tr[x][y]=tt;
}
return tr[x][y];
}
int main()
{
freopen("t4.in","r",stdin);
//freopen("t4.out","w",stdout);
scanf("%s",s+1);
n=strlen(s+1);
fo(i,1,n/2)
{
a[i*2-1]=s[i]-'a';
a[i*2]=s[n-i+1]-'a';
}
a[0]=-1;
a[n+1]=-1;
len[1]=-1;
len[2]=0;
fail[1]=nxt[1]=1;
fail[2]=1;
nxt[2]=1;
diff[2]=1;
tt=2;
f[0]=1;
lst[0]=2;
fo(i,1,n)
{
lst[i]=ins(lst[i-1],a[i]);
for(p=lst[i];p>2;p=nxt[p])
{
g[p]=f[i-(len[nxt[p]]+diff[p])];
if (diff[p]==diff[fail[p]])
(g[p]+=g[fail[p]])%=mo;
if (i%2==0) (f[i]+=g[p])%=mo;
}
}
printf("%d",f[n]);
}