今天刷LeetCode时看到一道easy的题目,也就是lc.28: 实现 strStr(),我顿时好家伙,这个应该能做的出来~吧。
开始想的是暴力解法,但是发现时间复杂度是O(n*m)的,虽然能AC,但是太慢了,不太好。然后就想用栈去写,发现总有几个案例过不去,改了半天,心态崩了。最后去看题解,好家伙,人家一句代码(haystack.indexOf(needle))就结束了-_- 。后面看到别人说可以用KMP去解,然后就学习了下KMP。话说,学完KMP算法后发现,要用KMP算法的题都算是easy题了嘛,太哈人了。
好了,回归重点:
1、KMP算法的介绍:
网上好多,我就不复制粘贴了,大伙们可以百度一下。
总之,对于字符串的匹配问题(字符串长度为n,待匹配字符串长度为m,n<=m),用暴力的方法去解的话时间复杂度是O(n*m)的,而用KMP算法去解的话,时间复杂度仅有O(n+m),也就是O(n)了。两者的不同之处就在于,KMP算法是通过维护一个next数组,对暴力方法中的回退机制进行了改进,以此来减少字符匹配的次数的,使得可以利用更少的次数找到匹配的字符串,这有点像利用空间去换时间的方式。
2、next数组的计算与代码实现讲解:
要实现KMP算法,最为重要的就是算出next数组,next数组中存储的是字符串的最长相同前后缀的长度,前后缀不清楚的话可以百度下,就是不包括其另一个边界元素所构成的字符串,如aab,则后缀只能是b、ab,不能包括最前面的a,即aab;前缀也只能是a、aa,不包括aab。
2.1 next数组的求解如下:
设字符串Text为" aabaabaaf “,长度为n,
待匹配字符串pattern为” aabaaf ",长度为m。
(由于next数组的计算只与pattern的长度有关,因此,求出next数组的时间复杂度仅为O(m)。)
则pattern的子串有:" a “、” aa “、” aab “、” aaba “、” aabaa "、 " aabaaf ",分别计算其最长相同前后缀的长度后可得:
" a " -> 为0,因为单一个a既可以说是前缀也可以说是后缀,没有相同部分,所以通常设为0
" aa " -> 为1,前缀为a,后缀为a
" aab " -> 为0
" aaba " -> 为1,前缀为a,后缀为a
" aabaa " -> 为2,前缀为aa,后缀为aa,记得后缀是从前面开始算的,如aab是前缀,而对应的后缀则是baa,不是aab
" aabaaf " -> 为0
故最后算得next=[0, 1, 0, 1, 2, 0]
其实,目前网上有好几种表达next数组的形式,①如对上述算得的next数组整体减1,②又或者对上述算得的整体next数组右移一位,第一位设为-1等等,但其最终的原理都是一样的,只是代码的实现方式有点区别。像整体减1的,最终会在代码中再加回1,其目的也是为了凑出最前面的-1;而将next数组整体右移的,也是为了凑出开头的-1。
即next数组可以表现为:①next=[0, 1, 0, 1, 2, 0] / ②next=[-1, 0, -1, 0, 1, -1] / ③next=[-1, 0, 1, 0, 1, 2]。在本文中,主要是对第①种和第③种的next数组进行实现。
其实如果是手算的话,第③种的next数组可能会好算些,而且这种格式在后面的代码种也方便些,可以直接根据下标回退跳转就行。开始设next[0] = -1,然后next第二位就填字符串第二位前所有字符串(但不包括不包括第二位字符)的最长相同前后缀的长度即可,如aab,则next[0] = -1;next[1]就看第二个a前面的字符串最长相同前后缀,即看a,所以next[1] = 0;然后就是next[3],看第三位的前面字符串,即aa,则next[3] = 1,最后最长的字符串就不用算了,可以少算一个,最终next数组即为[-1, 0, 1]。
2.2 next数组代码实现的解释
我按照第③种next数组的形式来进行解释
不知道这张图的解释能不能让你明白,就如图中char中斜对应next数组的这样,比如你算下标为2的next数组值,其实你是对下标为2前的字符串计算得出的。而当两个字符不相等时,则就根据next表进行回退,回退后再判断字符是否相同(注:回退是个连续的过程,所以要用while来进行回退;且回退是有边界的,当回退到0了,则不能再回退了)。
而至于为什么指针 j 不是回退到1而要回退到next[j],则与你之前构建的next数组值有关,因为当前字符不相等是吧,那在这一串的字符串中最大的相同前后缀长度必然不是这么长的,那就退回去,重新找一个短一点的相同的前后缀。当然了,最惨的就是回溯到0了,相当于加上当前这个字符,就不存在相同的前后缀了。
2.3 next数组运用C++代码的表现形式如下:
2.3.1 按原位置生成的next数组
/*
* Func: GetNextArr1
* 按原位置生成next数组
* 例如:aaab -> [0,1,2,0]
*/
void GetNextArr1(vector<int>& next, string pattern) {
int j = 0, len = next.size();
for (int i = 1; i < len; i++) {
// 回退到上一个
while (j > 0 && pattern[i] != pattern[j]) {
j = next[j - 1];
}
if (pattern[i] == pattern[j]) {
next[i] = j++;
}
}
}
2.3.2 向右平移一个位置生成next数组,其中next[0] = -1
void GetNextArr2(vector<int>& next, string pattern) {
// 初始位为-1,即next[-1, ... ...]
next[0] = -1;
int j = 0, len = next.size(); // j设为前缀末尾
for (int i = 1; i < len; i++) {
next[i] = j;
// 回退
while (j > 0 && pattern[i] != pattern[j]) {
j = next[j];
}
if (pattern[i] == pattern[j]) {
j++;
}
}
}
3、KMP算法的代码实现:
我已经将两种生成next数组的方法及其对应的KMP算法实现都写在对应的类里了。如对于next数组的计算,就在GetNext类里;其对应的KMP算法则写在KMP类里了,其中:
GetNextArr1 对应 KMP1_1 与 KMP1_2
GetNextArr2 对应 KMP2_1 与 KMP2_2
同一KMP算法的实现(如 KMP1_1 与 KMP1_2)仅在代码的输出实现有所区别。就比如KMP1_1函数是找到第一组匹配的字符串就返回其对应的开头位置,不对剩下的Text字符串继续寻找是否还存在匹配的字符串;而KMP1_2的代码则是找到所有符合的字符串开头位置。
对于代码的字符串输入,读取字符串的函数是与LeetCode是一样的,所有按照LeetCode的字符串输入格式来输入就可以了
3.1 运行效果:
3.2 全部的实现代码(Github:https://github.com/DeepVegChicken/Learning-KMP_Algorithm):
#include<iostream>
#include<vector>
#include<string>
#include<sstream>
using namespace std;
class GetNext {
public:
/*
* Func: GetNextArr1
* 按原位置生成next数组
* 例如:aaab -> [0,1,2,0]
*/
void GetNextArr1(vector<int>& next, string pattern) {
int n = next.size();
for (int i = 1, j = 0; i < n; i++) {
// 回退
while (j > 0 && pattern[i] != pattern[j]) {
j = next[j - 1];
}
if (pattern[i] == pattern[j]) {
next[i] = j++;
}
}
}
/*
* Func: GetNextArr2
* 向右平移一个生成next数组,其中next[0] = -1
* 例如:aaab -> [-1,0,1,2]
*/
void GetNextArr2(vector<int>& next, string pattern) {
// 初始位为-1,即next[-1, ... ...]
next[0] = -1;
int j = 0, len = next.size(); // j设为前缀末尾
for (int i = 1; i < len; i++) {
next[i] = j;
// 回退
while (j > 0 && pattern[i] != pattern[j]) {
j = next[j];
}
if (pattern[i] == pattern[j]) {
j++;
}
}
}
};
class KMP {
public:
/*
* Func: KMP1_1
* 按原位置生成next数组的KMP算法实现
* 只找出第一组匹配的字符串
* Return: 匹配失败返回-1,成功则返回头位置
*/
int KMP1_1(string text, string pattern, vector<int> &next) {
int j = 0;
int m = text.size(), n = pattern.size();
for (int i = 0; i < m; i++) {
while (j > 0 && text[i] != pattern[j]) {
j = next[j - 1];
}
if (text[i] == pattern[j]) {
j++;
}
// 匹配成功
if (j == n) {
return i - n + 1;
}
}
return -1;
}
/*
* Func: KMP1_2
* 按原位置生成next数组的KMP算法实现
* 可找出多组匹配的字符串
*/
void KMP1_2(string text, string pattern, vector<int>& next, vector<int>& retStarArr) {
// 获取next数组
int j = 0;
int m = text.size(), n = pattern.size();
for (int i = 0, j = 0; i < m; i++) {
while (j > 0 && text[i] != pattern[j]) {
j = next[j - 1];
}
if (text[i] == pattern[j]) {
j++;
}
// 匹配成功
if (j == n) {
retStarArr.push_back(i - n + 1);
j = 0;
}
}
}
/*
* Func: KMP2_1
* next数组向右平移一格的KMP算法实现
* 只找出第一组匹配的字符串
* Return: 匹配失败返回-1,成功则返回头位置
*/
int KMP2_1(string text, string pattern, vector<int>& next) {
int j = 0;
int m = text.size(), n = pattern.size();
for (int i = 0; i < m; i++) {
while (j > 0 && text[i] != pattern[j]) {
j = next[j];
}
if (text[i] == pattern[j]) {
j++;
}
// 匹配成功
if (j == n) {
return i - n + 1;
}
}
return -1;
}
/*
* Func: KMP2_2
* next数组向右平移一格的KMP算法实现
* 可找出多组匹配的字符串
*/
void KMP2_2(string text, string pattern, vector<int>& next, vector<int>& retStarArr) {
int j = 0;
int m = text.size(), n = pattern.size();
for (int i = 0; i < m; i++) {
while (j > 0 && text[i] != pattern[j]) {
j = next[j];
}
if (text[i] == pattern[j]) {
j++;
}
// 匹配成功
if (j == n) {
retStarArr.push_back(i - n + 1);
j = 0;
}
}
}
};
string stringToString(string input) {
string result;
for (int i = 1; i < input.length() - 1; i++) {
char currentChar = input[i];
if (input[i] == '\\') {
char nextChar = input[i + 1];
switch (nextChar) {
case '\"': result.push_back('\"'); break;
case '/': result.push_back('/'); break;
case '\\': result.push_back('\\'); break;
case 'b': result.push_back('\b'); break;
case 'f': result.push_back('\f'); break;
case 'r': result.push_back('\r'); break;
case 'n': result.push_back('\n'); break;
case 't': result.push_back('\t'); break;
default: break;
}
i++;
}
else {
result.push_back(currentChar);
}
}
return result;
}
string boolToString(bool input) {
return input ? "True" : "False";
}
void Printf(vector<int> &v) {
for (auto x : v) {
cout << x << " ";
}
cout << endl;
}
int main() {
string line;
while (getline(cin, line)) {
string text = stringToString(line);
getline(cin, line);
string pattern = stringToString(line);
if (pattern.empty()) {
return 0;
}
if (text.empty()) {
return -1;
}
// 字符串长度
int tLen = text.size(), pLen = pattern.size();
/*
* 按原位置生成的next数组
*/
cout << "按原位置生成next数组: " << endl;
vector<int> next1(pLen), retArr1;
GetNext().GetNextArr1(next1, pattern);
Printf(next1); // 验证next数组生成是否正确
// ①
int num1 = KMP().KMP1_1(text, pattern, next1);
cout << num1 << endl;
// ②
KMP().KMP1_2(text, pattern, next1, retArr1);
if (retArr1.empty()) {
// 因为返回-1后编译器直接就结束程序了
// 所以当遇到不匹配的时候返回-1时,想要测试下面的方法就得单独测试
return -1;
}
else {
Printf(retArr1);
}
cout << endl;
/*
* 向右平移一个生成的next数组
*/
cout << "向右平移一个生成next数组: " << endl;
vector<int> next2(pLen), retArr2;
GetNext().GetNextArr2(next2, pattern);
Printf(next2); // 验证next数组生成是否正确
// ①
int num2 = KMP().KMP2_1(text, pattern, next2);
cout << num2 << endl;
// ②
KMP().KMP2_2(text, pattern, next2, retArr2);
if (retArr2.empty()) {
return -1;
}
else {
Printf(retArr2);
}
}
return 0;
}