EXKMP
介绍
EXKMP（扩展KMP）用于求两个字符串之间的最长公共前缀长度：设A为主串、B为模式串（副串），定义数组ex，ex[i]表示A的后缀A[i…Alen]与B[1…Blen]的最长公共前缀长度。求ex之前，先求辅助数组p，p[i]表示B的后缀B[i…Blen]与B本身的最长公共前缀长度（即B自己匹配自己）。
模板
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
using namespace std;
char sa[1100000],sb[1100000];
int lena,lenb;
int p[1100000],ex[1100000];
// Extended KMP (Z-algorithm) on 1-indexed strings A = sa[1..lena] (text) and
// B = sb[1..lenb] (pattern). Both strings must be NUL-terminated at index len+1;
// the direct scans below rely on that '\0' to stop.
//   p[i]  = length of the longest common prefix of B[i..lenb] and B (B against itself)
//   ex[i] = length of the longest common prefix of A[i..lena] and B
void exkmp()
{
    // ---- Phase 1: p[] (B matched against itself) ----
    p[1] = lenb;  // B trivially matches itself in full
    // p[1] spans all of B, so it cannot serve as the reference window.
    // Compute p[2] by direct comparison and start the window at position 2.
    int run = 0;
    while (run + 2 <= lenb && sb[run + 1] == sb[run + 2]) ++run;
    p[2] = run;
    int best = 2;  // start of the match window B[best..best+p[best]-1] that reaches furthest right
    for (int i = 3; i <= lenb; ++i)
    {
        const int reach = best + p[best] - 1;  // last index covered by that window
        const int mirrored = p[i - best + 1];  // known answer at the matching offset inside B
        // If the mirrored match ends strictly before `reach` (i + mirrored - 1 < reach),
        // it cannot be extended and is already the exact answer.
        if (i + mirrored <= reach)
        {
            p[i] = mirrored;
            continue;
        }
        // Otherwise reach - i + 1 characters are already known to match (none if i lies
        // past the window). Extend by direct comparison, then move the window to i.
        int len = reach - i + 1;
        if (len < 0) len = 0;
        while (i + len <= lenb && sb[len + 1] == sb[i + len]) ++len;
        p[i] = len;
        best = i;
    }

    // ---- Phase 2: ex[] (A matched against B), reusing p[] ----
    // ex[1] has no earlier window to copy from, so compare directly. If lena < lenb,
    // the scan stops at the '\0' in sa[lena+1].
    run = 0;
    while (run + 1 <= lenb && sa[run + 1] == sb[run + 1]) ++run;
    ex[1] = run;
    best = 1;
    for (int i = 2; i <= lena; ++i)
    {
        const int reach = best + ex[best] - 1;
        const int mirrored = p[i - best + 1];
        if (i + mirrored <= reach)
        {
            ex[i] = mirrored;
            continue;
        }
        int len = reach - i + 1;
        if (len < 0) len = 0;
        while (i + len <= lena && len <= lenb && sb[len + 1] == sa[i + len]) ++len;
        ex[i] = len;
        best = i;
    }
}
int main()
{
    // Input: A (main string) then B (pattern), each stored from index 1.
    // The field widths keep each string plus its terminating NUL inside the
    // 1100000-byte buffers (an unbounded "%s" could overflow them). Stop if two
    // strings could not be read.
    if (scanf("%1099998s%1099998s", sa + 1, sb + 1) != 2) return 1;
    lena = strlen(sa + 1);
    lenb = strlen(sb + 1);
    exkmp();
    // Output ex[1..lena], space-separated and newline-terminated.
    for (int i = 1; i < lena; i++) printf("%d ", ex[i]);
    printf("%d\n", ex[lena]);
    return 0;
}
例题
#include <bits/stdc++.h>
//#pragma GCC optimize(2)
#define int long long
using namespace std;
// Contest-template boilerplate. None of LL, inf, mod, maxn or N is referenced by
// exkmp / solve / main in this example.
// Alternative modulus, currently disabled:
//const int mod = 998244353;
typedef long long LL;
// `int` is #defined as long long above, so 1e18 fits in this constant.
const int inf = 1e18;
const int mod = 1e9 + 7;
const int maxn = 1e5 + 10;
const int N = 1e7 + 10000;
char sa[1100000], sb[1100000];
int lena, lenb;
int p[1100000], ex[1100000];
// Extended KMP (Z-algorithm) on 1-indexed strings A = sa[1..lena] and B = sb[1..lenb].
// Both must be NUL-terminated at index len+1; the direct scans rely on that '\0'.
//   p[i]  = LCP(B[i..lenb], B)  -- B matched against itself
//   ex[i] = LCP(A[i..lena], B)  -- the array consumed by solve()
// `int` is #defined as long long in this program, so all of these are 64-bit here.
void exkmp() {
    // ---- Phase 1: p[] ----
    p[1] = lenb;  // B matches itself completely
    // p[1] cannot act as the reference window, so p[2] is found by direct comparison.
    int run = 0;
    while (run + 2 <= lenb && sb[run + 1] == sb[run + 2]) ++run;
    p[2] = run;
    int best = 2;  // start of the match window that reaches furthest to the right
    for (int i = 3; i <= lenb; ++i) {
        const int reach = best + p[best] - 1;  // last index covered by that window
        const int mirrored = p[i - best + 1];  // answer at the matching offset inside B
        // The mirrored match ends strictly inside the window (i + mirrored - 1 < reach),
        // so it cannot grow and is already exact.
        if (i + mirrored <= reach) {
            p[i] = mirrored;
            continue;
        }
        // Otherwise reuse the reach - i + 1 characters known to match (none if i is past
        // the window), extend by direct comparison and move the window to i.
        int len = reach - i + 1;
        if (len < 0) len = 0;
        while (i + len <= lenb && sb[len + 1] == sb[i + len]) ++len;
        p[i] = len;
        best = i;
    }

    // ---- Phase 2: ex[], reusing p[] ----
    // ex[1] is compared directly. When lena < lenb, the '\0' at sa[lena+1] ends the scan.
    run = 0;
    while (run + 1 <= lenb && sa[run + 1] == sb[run + 1]) ++run;
    ex[1] = run;
    best = 1;
    for (int i = 2; i <= lena; ++i) {
        const int reach = best + ex[best] - 1;
        const int mirrored = p[i - best + 1];
        if (i + mirrored <= reach) {
            ex[i] = mirrored;
            continue;
        }
        int len = reach - i + 1;
        if (len < 0) len = 0;
        while (i + len <= lena && len <= lenb && sb[len + 1] == sa[i + len]) ++len;
        ex[i] = len;
        best = i;
    }
}
void solve() {
cin>>sb+1>>sa+1;
lena = strlen(sa + 1);
lenb = strlen(sb + 1);
exkmp();
int sum = 0;
for (int i = 1; i <=ex[1]; i++) sum+=ex[i+1];
cout << sum << "\n";
}
// Program entry. `signed` is spelled out because `int` is #defined as long long here.
signed main() {
    // Fast I/O is intentionally left off; to enable it use:
    // ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int cases = 1;  // single test case; for multi-test input read the count instead:
    // cin >> cases;
    while (cases-- != 0) solve();
    return 0;
}
#pragma GCC optimize(2)
#include <bits/stdc++.h>
// all(c): expands to the full iterator range of container c, e.g. sort(all(v)).
#define all(n) (n).begin(), (n).end()
// pair accessors and container shorthands (note: pb expands to emplace_back).
#define se second
#define fi first
#define pb emplace_back
#define mp make_pair
// Square of an expression; the argument is evaluated twice, so avoid side effects in it.
#define sqr(n) ((n)*(n))
// Inclusive loops: rep counts i up from a to b, per counts i down from a to b.
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
#define per(i, a, b) for (int i = (a); i >= (b); --i)
// Stream manipulator: fixed-point output with `a` digits after the decimal point.
#define precision(a) setiosflags(ios::fixed) << setprecision(a)
// Desynchronise iostreams from C stdio and untie cin/cout for faster stream I/O.
// After it runs, do not mix cin/cout with getchar/putchar-based input/output.
#define IOS ios::sync_with_stdio(false);cin.tie(0);cout.tie(0)
using namespace std;
// Short type aliases.
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<ll, ll> PLL;
typedef vector<int> VI;
typedef vector<long long> VL;
typedef double db;
// Reads the next integer from stdin (via getchar) into x.
// Non-digit characters before the number are skipped; if a '-' appears among them the
// result is negated. Returns the character that ended the number, or EOF narrowed to
// char if the input ran out. If input ends before any digit, x is left as 0.
template <typename T> inline char read(T &x) {
    x = 0;
    T sign = 1;
    // Must be int, not char: a char cannot hold EOF distinctly, so the skip loop would
    // spin forever at end of input. isdigit() also requires an unsigned-char value or EOF.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (isdigit(ch)) {  // isdigit(EOF) is false, so EOF also ends this loop
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    x = sign * x;
    return static_cast<char>(ch);
}
// Reads several integers in order: read(a, b, c).
template <typename T, typename... Args> inline void read(T &x, Args &... args) { read(x), read(args...); }
// Writes the integer x to stdout in decimal via putchar (no separator or newline).
// Digits are taken from x directly instead of negating it first: `x = -x` overflows
// (undefined behaviour) for the most negative value of a signed type.
template <typename T> inline void write(T x) {
    char digits[48];  // enough for any value up to 128 bits
    int len = 0;
    const bool negative = x < 0;
    do {
        T rem = x % 10;  // in [-9, 0] when x is negative (division truncates toward zero)
        digits[len++] = static_cast<char>('0' + (negative ? -rem : rem));
        x /= 10;
    } while (x != 0);
    if (negative) putchar('-');
    while (len > 0) putchar(digits[--len]);
}
// Writes several integers back to back with no separator: write(a, b, c).
// Takes const references so temporaries such as write(a + b, c) are accepted too.
template <typename T, typename... Args> inline void write(const T& x, const Args&... args) { write(x), write(args...); }
// Lowers `a` to `b` when a > b. Returns whether `a` was changed.
template<class T1, class T2> bool umin(T1& a, T2 b) {
    if (a > b) {
        a = b;
        return true;
    }
    return false;
}
// Raises `a` to `b` when a < b. Returns whether `a` was changed.
template<class T1, class T2> bool umax(T1& a, T2 b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// Resets `a` to a value-initialised T by swapping it with a fresh object. The old
// contents (and, for standard containers, their buffer) are destroyed with that object.
template<class T> void clear(T& a) {
    T fresh{};
    fresh.swap(a);
}
const int N = 1e5 + 5;
int n, m, _, k, cas;
int lens, lent, f[N], extend[N];
char s[N], t[N];
// Z-array of the 1-indexed string t[1..lent]: f[i] = LCP(t[i..lent], t[1..lent]).
// Writes f[1] and f[2] even when lent < 2.
void kmp(char* t, int lent) {
    f[1] = lent;
    // f[2] is computed directly; from i = 3 on, the window starting at `best` is reused.
    int run = 0;
    while (run + 2 <= lent && t[run + 1] == t[run + 2]) ++run;
    f[2] = run;
    int best = 2;  // start of the match window reaching furthest right
    for (int i = 3; i <= lent; ++i) {
        const int reach = best + f[best] - 1;  // last index covered by that window
        const int mirrored = f[i - best + 1];
        if (i + mirrored - 1 < reach) {  // mirrored match ends inside the window: exact
            f[i] = mirrored;
            continue;
        }
        int len = reach - i + 1 > 0 ? reach - i + 1 : 0;
        while (i + len <= lent && t[len + 1] == t[i + len]) ++len;
        f[i] = len;
        best = i;
    }
}
// extend[i] = LCP(s[i..lens], t[1..lent]) for 1-indexed s and t.
// kmp(t, lent) must have filled f[] first.
void ex_kmp(char *s, char *t, int lens, int lent) {
    const int limit = std::min(lens, lent);
    int len = 0;
    while (len + 1 <= limit && s[len + 1] == t[len + 1]) ++len;
    extend[1] = len;
    int best = 1;
    for (int i = 2; i <= lens; ++i) {
        const int reach = best + extend[best] - 1;
        const int mirrored = f[i - best + 1];
        if (i + mirrored - 1 < reach) {
            extend[i] = mirrored;
            continue;
        }
        len = reach - i + 1 > 0 ? reach - i + 1 : 0;
        while (i + len <= lens && len + 1 <= lent && t[len + 1] == s[i + len]) ++len;
        extend[i] = len;
        best = i;
    }
}
// Reads pattern t then text s and prints sum_{i=1..extend[1]} extend[i+1]. Equivalently,
// it counts pairs (i, j >= 1) with s[1..i] == t[1..i] and s[i+1..i+j] == t[1..j].
int main() {
    // Same effect as the IOS macro.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    // Read through std::string. `cin >> t + 1` extracts into a raw char* with no length
    // limit on the N-byte buffers, and that overload was removed in C++20.
    std::string pat, txt;
    if (!(std::cin >> pat >> txt)) return 0;
    // Index 0 stays unused and index len+1 holds the terminating NUL.
    if (pat.size() + 2 > sizeof t || txt.size() + 2 > sizeof s) return 1;
    std::memcpy(t + 1, pat.c_str(), pat.size() + 1);
    std::memcpy(s + 1, txt.c_str(), txt.size() + 1);
    lent = static_cast<int>(pat.size());
    lens = static_cast<int>(txt.size());
    kmp(t, lent);
    ex_kmp(s, t, lens, lent);
    long long ans = 0;
    // A split at i == lens leaves an empty remainder (contributes 0). Stop there instead of
    // reading extend[lens + 1], which ex_kmp never writes.
    for (int i = 1; i <= extend[1] && i + 1 <= lens; ++i) ans += extend[i + 1];
    std::cout << ans;
    return 0;
}