给定一个串,把串分为偶数段
假设分为了s1,s2,s3…sk
求,满足s1=sk,s2=sk−1… 的方案数。
当你把原串S变成S′=S[1]S[n]S[2]S[n−2]… 后,这题的问题就变成了有多少种切分S’的方法使得每一段都是回文串。
(这个转化已经很神仙了,这道题就出到这里我都觉得已经很nb了)
回文切分方法这个东西是一个经典的n方dp,dp[i]表示当前前缀i有多少种切分方法,dp[i]要从i的所有后缀回文位置dp[j]转移过来,累加就是当前dp[i]的答案。
这时候你就需要一个玩意儿帮你搞定当前前缀i位置的所有回文后缀。考虑使用回文树,首先找到前缀i位置对应的回文树节点,之后向上一条fail链就是当前位置的所有后缀回文,当fail树整个是一条链的时候,复杂度退化到n方。
然而这题的串长1e6,n方显然是远远做不动的,需要想办法优化。
怎么办呢?翻开金策-字符串算法选讲,翻到23页,就提到了这个dp的优化,而且讲了一些很nb的知识…
s 是回文串, 则 s 的后缀 t 是回文串当且仅当 t 是 s border。
也就是说如果一个回文串的后缀也是一个回文,那么这个后缀一定跟该回文串的相同长度前缀相等。
既然这样就可以利用一下border的性质。
一个字符串的所有回文后缀的长度可以表示成 O(logn)个等差数列。
忽然很nb的做法就诞生了…
根据这两条结论,我们发现一条fail链的组成是有规律的,它可以划分成logn个等差数列。
借用这篇博客的一张图:
这是一组等差回文后缀,设当前位置为i,最长的一个后缀长度为a1,公差为d,那么根据border的性质可以知道s[i - a1,i - d]是一个回文串,s[i - a2,i - d]也是一个回文串…一直到等差数列的最后一项,这个时候[i - an,i - d]就不再是回文串了。
也就是说,在i这个位置,我们想求的分割数,其实和其在回文树上的fail位置之差了一个等差数列的末项没有统计(前提是fail还在同一个等差数列内)!
有了这个结论之后,就可以愉快的按等差分组,一组一组的转移了。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
const int maxn = 1e6 + 5;
const int mod = 1e9 + 7;
char s[maxn], t[maxn];
struct Pam {
int next[maxn][26];
int fail[maxn];
int diff[maxn], par[maxn];
int pos[maxn];
int f[maxn], g[maxn];
int len[maxn];
int S[maxn];
int last, n, p;
int newNode(int l) {
memset(next[p], 0, sizeof(next[p]));
len[p] = l;
return p++;
}
void init() {
n = last = p = 0;
newNode(0);
newNode(-1);
S[n] = -1;
fail[0] = 1;
}
int getFail(int x) {
while (S[n - len[x] - 1] != S[n]) {
x = fail[x];
}
return x;
}
void add(int c) {
S[++n] = c;
int cur = getFail(last);
if (!next[cur][c]) {
int now = newNode(len[cur] + 2);
fail[now] = next[getFail(fail[cur])][c];
next[cur][c] = now;
//按等差分组,par接到上一个等差数列的末项。
diff[now] = len[now] - len[fail[now]];
if (diff[now] != diff[fail[now]]) {
par[now] = fail[now];
} else {
par[now] = par[fail[now]];
}
}
last = next[cur][c];
}
void build() {
init();
for (int i = 1; s[i]; i++) {
add(s[i] - 'a');
pos[i] = last;
}
}
void solve() {
int lenn = strlen(s + 1);
f[0] = 1;
for (int i = 1; i <= lenn; ++i) {
for (int p = pos[i]; p > 1; p = par[p]) {
//加上首项
g[p] = f[i - len[par[p]] - diff[p]];
if (diff[p] == diff[fail[p]]) {
g[p] = (g[p] + g[fail[p]]) % mod;
}
if (i % 2 == 0) {
f[i] = (f[i] + g[p]) % mod;
}
}
}
printf("%d\n", f[lenn]);
}
} pam;
int main() {
scanf("%s", t + 1);
int len = strlen(t + 1);
for (int i = 1; i + i <= len; ++i) {
s[i * 2 - 1] = t[i];
s[i * 2] = t[len - i + 1];
}
pam.build();
pam.solve();
return 0;
}