扩展 KMP
简述
给你两个字符串 a a a, b b b,长度分别为 n n n, m m m。
请输出 b b b 的每一个后缀与 b b b 的最长公共前缀以及 a a a 的每一个后缀的最长公共前缀。
扩展 KMP
:求出字符串
a
a
a 的所有后缀与
b
b
b 的最长公共前缀长度,时间复杂度
O
(
∣
S
∣
+
∣
T
∣
)
\mathcal{O}(|S|+|T|)
O(∣S∣+∣T∣)。
该解法思想与 KMP
类似,所以称作扩展 KMP
。
求法
定义 n e x t i next_i nexti 为 b b b 由 i i i 开始的后缀与 b b b 的最长公共前缀长度。(即为自配)
定义 e x t i ext_i exti 为 a a a 由 i i i 开始的后缀与 b b b 的最长公共前缀长度。
这是本题的第一问,也是第二问的辅助数组。
易得,此时 n e x t 0 = ∣ b ∣ next_0=|b| next0=∣b∣。
思考,如果已知 n e x t 0 ∼ n e x t x − 1 next_0 \sim next_{x-1} next0∼nextx−1,如何求 n e x t x ? next_x \ ? nextx ?
有一个较具体的例子,如果 k = 121 k=121 k=121,那么 n e x t 0 next_{0} next0 到 n e x t 120 next_{120} next120 都已经计算完毕,且 l 120 = 100 l_{120}=100 l120=100, r 120 = 130 r_{120}=130 r120=130。
这意味着 b [ 100..130 ] = b [ 0..3 ] b[100..130]=b[0..3] b[100..130]=b[0..3],那么 b [ 121..130 ] = b [ 21..30 ] b[121..130]=b[21..30] b[121..130]=b[21..30],这样 n e x t 21 next_{21} next21 对于计算 n e x t 121 next_{121} next121 就非常有帮助,如果 n e x t 21 = 3 next_{21}=3 next21=3,那么 n e x t 121 = 3 next_{121}=3 next121=3。
设 n e x t 0 ∼ n e x t k next_0 \sim next_k next0∼nextk 已经算好,记 p p p 为在以前的匹配过程中在 b b b 串中的最远位置,即 p = m a x ( i + n e x t [ i ] − 1 ) p=max(i+next[i]-1) p=max(i+next[i]−1),其中 i = 1... k i=1...k i=1...k。
设取到这个最大值 p p p 的位置是 p 0 p0 p0。
则 a [ p 0... p ] = b [ 0... p − p 0 ] a[p0...p]=b[0...p−p0] a[p0...p]=b[0...p−p0]。
假设 i + n e x t [ i − k ] ≤ p i+next[i-k]≤p i+next[i−k]≤p,则 n e x t x = n e x t [ x − k ] next_x=next[x-k] nextx=next[x−k]。
否则,暴力枚举 n e x t x next_x nextx。
求 e x t ext ext 同理,若 i + n e x t [ i − k ] ≤ p i+next[i-k]≤p i+next[i−k]≤p,则 e x t x = n e x t [ x − k ] ext_x=next[x-k] extx=next[x−k]。
否则,暴力枚举 e x t x ext_x extx。
由于此时的 k k k 不降,所以时间复杂度为 O ( ∣ S ∣ ) \mathcal O(|S|) O(∣S∣)。
t o b e c o n t i n u e \Huge{to \ be \ continue} to be continue
代码实现
#include <bits/stdc++.h>
using namespace std;
#define _ (int) 3e7 + 5
char a[_], b[_];
int n, m, nxt[_], ext[_];
long long Ans1, Ans2;
void get_nxt()
{
nxt[0] = m;
int j = 0;
while(j + 1 < m && b[j] == b[j + 1]) ++j;
nxt[1] = j;
int k = 1;
for(int i = 2; i < m; ++i)
{
int p = k + nxt[k] - 1;
if(i + nxt[i - k] <= p) nxt[i] = nxt[i - k];
else
{
j = max(p - i + 1, 0);
while(i + j < m && b[i + j] == b[j]) ++j;
nxt[i] = j;
k = i;
}
}
}
void get_ext()
{
int j = 0;
while(j < n && j < m && a[j] == b[j]) ++j;
ext[0] = j;
int k = 0;
for(int i = 1; i < n; ++i)
{
int p = k + ext[k] - 1;
if(i + nxt[i - k] <= p) ext[i] = nxt[i - k];
else
{
j = max(p - i + 1, 0);
while(i + j < n && j < m && a[i + j] == b[j]) ++j;
ext[i] = j;
k = i;
}
}
}
signed main()
{
scanf("%s%s", a, b);
n = strlen(a);
m = strlen(b);
get_nxt();
get_ext();
for(int i = 0; i < m; ++i) Ans1 ^= 1ll * (i + 1) * (nxt[i] + 1);
for(int i = 0; i < n; ++i) Ans2 ^= 1ll * (i + 1) * (ext[i] + 1);
printf("%lld\n%lld\n", Ans1, Ans2);
return 0;
}