题目:给串s和串t,求从s中取一个子串,再和t的一个前缀拼接起来能组成回文串的个数,要求从s取的子串比取的t的前缀长。
可以用扩展kmp+回文算法做,也可以用sam替代扩展kmp,但是很难写。
回文算法用马拉车和回文树都可以,但是这里要求的是“从当前位置开始有多少回文串”,马拉车的数组要经过处理才能用,回文树就可以直接用了…
果然回文树各种方面都完爆马拉车,只是有点长。
可以用扩展kmp求出s串的反串每个后缀与t串的最长公共前缀,然后用每个lcp * 下一个位置的回文个数,累加即是答案。
sam就比较麻烦了,一开始思考倒着建s串的sam,但是倒着建无法很好的在自动机上更新出每个位置的答案,只好正着建,然后倒着跑t串,跑到最后再跳link去统计答案。
然而有一个注意的地方就是跑到最后位置的匹配长度要单独记录,因为这个匹配长度可能不是自动机当前节点记录的最长长度,向上跳的时候就没问题了,直接求出当前节点代表的子串个数,再乘当前节点更新出来的回文个数,累加就是答案。
Sam + Pam的ac代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e6 + 5;
char s[maxn], t[maxn];
//回文树+后缀自动机解法
struct Pam {
int next[maxn][26];
int fail[maxn];
int len[maxn];// 当前节点表示回文串的长度
int num[maxn];// 到当前节点这里有多少本质不同的回文子串
int pa[maxn];
int S[maxn];
int last, n, p;
int newNode(int l) {
memset(next[p], 0, sizeof(next[p]));
len[p] = l;
num[p] = 0;
return p++;
}
void init() {
n = last = p = 0;
newNode(0);
newNode(-1);
S[n] = -1;
fail[0] = 1;
}
int getFail(int x) {
while(S[n - len[x] - 1] != S[n]) {
x = fail[x];
}
return x;
}
int add(int c) {
S[++n] = c;
int cur = getFail(last);
if(!next[cur][c]) {
int now = newNode(len[cur] + 2);
fail[now] = next[getFail(fail[cur])][c];
next[cur][c] = now;
num[now] = num[fail[now]] + 1;
}
last = next[cur][c];
return num[last];
}
void build() {
init();
int len = strlen(s);
for(int i = len - 1; i >= 0; i--) {
pa[i] = add(s[i] - 'a');
}
// for(int i = 0; i < len; i++) {
// printf("%d ", pa[i]);
// }
// printf("\n");
}
} pam;
struct Sam {
int next[maxn << 1][26];
int link[maxn << 1], step[maxn << 1];
ll pa[maxn << 1];
int a[maxn], b[maxn << 1];
int sz, last, len, root;
void init() {
//如多次建立自动机,加入memset操作
root = sz = last = 1;
}
void add(int c, int id) {
int p = last;
int np = ++sz;
last = np;
pa[np] = pam.pa[id];
// printf("%d %d\n", id, pa[np]);
step[np] = step[p] + 1;
while(!next[p][c] && p) {
next[p][c] = np;
p = link[p];
}
if(p == 0) {
link[np] = root;
} else {
int q = next[p][c];
if(step[p] + 1 == step[q]) {
link[np] = q;
} else {
int nq = ++sz;
memcpy(next[nq], next[q], sizeof(next[q]));
step[nq] = step[p] + 1;
link[nq] = link[q];
link[q] = link[np] = nq;
while(next[p][c] == q && p) {
next[p][c] = nq;
p = link[p];
}
}
}
}
void build() {
init();
int len = strlen(s);
for(int i = 0; i < len; i++) {
add(s[i] - 'a', i + 1);
}
for(int i = 1; i <= sz; i++) {
a[step[i]]++;
}
for(int i = 1; i <= step[last]; i++) {
a[i] += a[i - 1];
}
for(int i = 1; i <= sz; i++) {
b[a[step[i]]--] = i;
}
for(int i = sz; i > 1; i--) {
int e = b[i];
pa[link[e]] += pa[e];
}
}
ll run() {
int p = root, now = 0;
int len = strlen(t);
for(int i = len - 1, c; i >= 0; i--) {
c = t[i] - 'a';
if(next[p][c]) {
p = next[p][c];
++now;
continue;
}
while(!next[p][c] && p) {
p = link[p];
}
if(!p) {
p = root;
now = 0;
} else {
now = step[p] + 1;
p = next[p][c];
}
}
ll ans = (now - step[link[p]]) * pa[p];
if(p != root) {
p = link[p];
}
while(p != root) {
ans += (step[p] - step[link[p]]) * pa[p];
p = link[p];
}
return ans;
}
} sam;
void solve() {
pam.build();
sam.build();
printf("%lld\n", sam.run());
}
int main() {
scanf("%s%s", s, t);
solve();
return 0;
}
然后是复制粘贴俩板子就ac的exKMP+Pam做法…真的无脑:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e6 + 5;
char s[maxn], s2[maxn];
struct Pam {
int next[maxn][26];
int fail[maxn];
int len[maxn];// 当前节点表示回文串的长度
int num[maxn];// 到当前节点这里有多少本质不同的回文子串
int pa[maxn];
int S[maxn];
int last, n, p;
int newNode(int l) {
memset(next[p], 0, sizeof(next[p]));
len[p] = l;
num[p] = 0;
return p++;
}
void init() {
n = last = p = 0;
newNode(0);
newNode(-1);
S[n] = -1;
fail[0] = 1;
}
int getFail(int x) {
while(S[n - len[x] - 1] != S[n]) {
x = fail[x];
}
return x;
}
int add(int c) {
S[++n] = c;
int cur = getFail(last);
if(!next[cur][c]) {
int now = newNode(len[cur] + 2);
fail[now] = next[getFail(fail[cur])][c];
next[cur][c] = now;
num[now] = num[fail[now]] + 1;
}
last = next[cur][c];
return num[last];
}
void build() {
init();
int lenn = strlen(s);
// for(int i = 0; i < lenn; i++) {
for(int i = lenn - 1; i >= 0; i--) {
pa[i] = add(s[i] - 'a');
// pa[i + 1] += pa[i];
}
}
} pam;
struct KMP {
int next[maxn], extend[maxn];
void getNext(char str[]) {
int i = 0, j, po, len = strlen(str);
next[0] = len; //初始化next[0]
while(str[i] == str[i + 1] && i + 1 < len) i++;
next[1] = i; //计算next[1]
po = 1; //初始化po的位置
for(i = 2; i < len; i++) {
if(next[i - po] + i < next[po] + po) //第一种情况,可以直接得到next[i]的值
next[i] = next[i - po];
else { //第二种情况,要继续匹配才能得到next[i]的值
j = next[po] + po - i;
if(j < 0) j = 0; //如果i>po+next[po],则要从头开始匹配
while(i + j < len && str[j] == str[j + i]) j++;
next[i] = j;
po = i; //更新po的位置
}
}
}
//计算extend数组
void exKmp() {
int len = strlen(s), lent = strlen(s2);
getNext(s2); //计算子串的next数组
int pos = 0;
while(pos < len && pos < lent && s[pos] == s2[pos]) {
++pos;
}
extend[0] = pos;
int k = 0, L;
for(int i = 1; i < len; i++) {
pos = k + extend[k] - 1;
L = next[i - k];
if(i + L <= pos) {
extend[i] = L;
} else {
int j = pos - i + 1;
if(j < 0) {
j = 0;
}
while(i + j < len && j < lent && s[i + j] == s2[j]) {
++j;
}
extend[i] = j;
k = i;
}
}
}
} kmp;
void solve() {
ll ans = 0;
int len = strlen(s);
pam.build();
reverse(s, s + len);
kmp.exKmp();
for(int i = 1; i < len; i++) {
// printf("%d %d\n", kmp.extend[i], pam.cnt[len - i + 2]);
ans += 1LL * pam.pa[len - i] * kmp.extend[i];
}
printf("%lld\n", ans);
}
int main() {
scanf("%s%s", s, s2);
solve();
return 0;
}