KMP算法的介绍参见维基百科:
https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm
这篇文章的解释不错:
http://www.ruanyifeng.com/blog/2013/05/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm.html
算法的关键在于next数组(部分匹配表)的计算,相当于模式串自己跟自己进行匹配。
算法的时间复杂度为O(m+n),其中m为主串(haystack)的长度,n为模式串(needle)的长度。
这张图有助于对算法的理解:
(图片来自http://wiki.jikexueyuan.com/project/kmp-algorithm/define.html)
算法导论的伪代码:
C++代码(与算法导论的伪代码稍有不同;代码中的pattern数组即上文所说的next数组,它还有可以优化的地方):
// Finds needle inside haystack.
//
// Returns 0 when needle is empty and -1 when haystack is empty (and needle is
// not). Otherwise a table with one entry per needle character is filled in by
// GeneratePattern and the result of Match is returned.
int strStr(string haystack, string needle) {
    if (needle.empty()) {
        return 0;
    }
    if (haystack.empty()) {
        return -1;
    }

    // One table slot per character of the needle.
    vector<int> next(needle.size());
    GeneratePattern(needle, next);
    return Match(haystack, needle, next);
}
// Builds the KMP fallback table (the "next" array) of str into pattern.
//
// After the call pattern[0] is -1 and, for every j >= 1, pattern[j] is the
// length of the longest proper prefix of str[0..j-1] that is also a suffix of
// it. The -1 sentinel tells the matcher that not even the first character
// matched.
//
// str:     the string to preprocess (the needle).
// pattern: output table. It is grown to str.size() entries if it is smaller;
//          a larger table keeps its extra trailing entries untouched.
//
// An empty str has no table, so pattern is left unchanged in that case.
void GeneratePattern(const std::string &str, std::vector<int> &pattern) {
    const int len = static_cast<int>(str.size());
    if (len == 0) {
        // Writing pattern[0] below would be out of range for an empty table.
        return;
    }
    if (pattern.size() < str.size()) {
        // Make sure every pattern[j] written below is inside the vector.
        pattern.resize(str.size());
    }

    pattern[0] = -1;
    int j = 1;   // index of the table entry being computed
    int k = -1;  // candidate border length for str[0..j-2]; -1 means "none"
    while (j < len) {
        if (k == -1 || str[j - 1] == str[k]) {
            // The border can be extended by one character (or restarts at 0).
            ++k;
            pattern[j] = k;
            ++j;
        } else {
            // Mismatch: retry with the next shorter border.
            k = pattern[k];
        }
    }
}
// Searches haystack for needle using a precomputed KMP fallback table.
//
// haystack: the text to search in.
// needle:   the string to search for.
// pattern:  fallback table for needle with pattern[0] == -1, as produced by
//           GeneratePattern; it must have at least needle.size() entries.
//
// Returns the index in haystack where the first occurrence of needle starts,
// 0 when needle is empty, or -1 when needle does not occur.
int Match(const std::string &haystack, const std::string &needle, const std::vector<int> &pattern) {
    const int hlen = static_cast<int>(haystack.size());
    const int nlen = static_cast<int>(needle.size());
    if (nlen == 0) {
        // An empty needle matches at the start; this also keeps the loop below
        // from reading pattern[0] of an empty table.
        return 0;
    }

    int j = 0;  // position in haystack
    int k = 0;  // position in needle; starts at 0 (not -1) so haystack[0] is compared
    while (j < hlen) {
        if (k == -1 || haystack[j] == needle[k]) {
            // Either the characters agree, or k == -1 signals that needle[0]
            // itself mismatched: advance in haystack and continue from needle[k + 1].
            ++j;
            ++k;
            if (k == nlen) {
                return j - k;
            }
        } else {
            // Mismatch: shift the needle according to the fallback table
            // without moving backwards in haystack.
            k = pattern[k];
        }
    }
    return -1;
}
Golang代码:
package main
import (
"fmt"
)
// main prints the result of searching for "aa" in "abaa".
func main() {
	index := StrStr("abaa", "aa")
	fmt.Println(index)
}
// StrStr searches haystack for needle.
//
// It returns 0 when needle is empty and -1 when haystack is empty (and needle
// is not). Otherwise it allocates a table with one entry per byte of needle,
// fills it with GeneratePattern and returns the result of Match.
func StrStr(haystack string, needle string) int {
	if needle == "" {
		return 0
	}
	if haystack == "" {
		return -1
	}

	next := make([]int, len(needle))
	GeneratePattern(needle, next)
	return Match(haystack, needle, next)
}
// GeneratePattern fills pattern with the KMP fallback table (the "next"
// array) of str.
//
// pattern[0] is set to -1 and, for every j >= 1, pattern[j] is set to the
// length of the longest proper prefix of str[:j] that is also a suffix of
// str[:j]. The -1 sentinel tells the matcher that not even the first byte
// matched.
//
// pattern must have at least len(str) elements. An empty str has no table, so
// pattern is left untouched in that case (it may be nil or empty).
func GeneratePattern(str string, pattern []int) {
	length := len(str)
	if length == 0 {
		// Writing pattern[0] below would panic on an empty table.
		return
	}

	pattern[0] = -1
	j := 1  // index of the table entry being computed
	k := -1 // candidate border length for str[:j-1]; -1 means "none"
	for j < length {
		if k == -1 || str[j-1] == str[k] {
			// The border can be extended by one byte (or restarts at 0).
			k++
			pattern[j] = k
			j++
		} else {
			// Mismatch: retry with the next shorter border.
			k = pattern[k]
		}
	}
}
// Match searches haystack for needle using a precomputed KMP fallback table.
//
// pattern is the table for needle with pattern[0] == -1, as produced by
// GeneratePattern; it must have at least len(needle) elements.
//
// Match returns the byte index in haystack where the first occurrence of
// needle starts, 0 when needle is empty, or -1 when needle does not occur.
func Match(haystack string, needle string, pattern []int) int {
	hlen := len(haystack)
	nlen := len(needle)
	if nlen == 0 {
		// An empty needle matches at the start; this also keeps the loop
		// below from indexing needle[0] and pattern[0].
		return 0
	}

	j := 0 // position in haystack
	k := 0 // position in needle; starts at 0 (not -1) so haystack[0] is compared
	for j < hlen {
		if k == -1 || haystack[j] == needle[k] {
			// Either the bytes agree, or k == -1 signals that needle[0]
			// itself mismatched: advance in haystack and continue from
			// needle[k+1].
			j++
			k++
			if k == nlen {
				return j - k
			}
		} else {
			// Mismatch: shift the needle according to the fallback table
			// without moving backwards in haystack.
			k = pattern[k]
		}
	}
	return -1
}
Python代码:
def strStr(haystack, needle):
    """
    Find needle inside haystack.

    Returns 0 when needle is empty and -1 when haystack is empty (and needle
    is not). Otherwise a table with one entry per needle character is filled
    in by GeneratePattern and the result of Match is returned.

    :type haystack: str
    :type needle: str
    :rtype: int
    """
    haystack_len, needle_len = len(haystack), len(needle)
    if needle_len == 0:
        return 0
    if haystack_len == 0:
        return -1

    # One table slot per character of the needle.
    table = [0] * needle_len
    GeneratePattern(needle, table)
    return Match(haystack, needle, table)
def GeneratePattern(string, pattern):
    """
    Fill ``pattern`` in place with the KMP fallback table ("next" array) of
    ``string``.

    After the call ``pattern[0]`` is -1 and, for every ``j >= 1``,
    ``pattern[j]`` is the length of the longest proper prefix of
    ``string[:j]`` that is also a suffix of it. The -1 sentinel tells the
    matcher that not even the first character matched.

    An empty ``string`` has no table, so ``pattern`` is left unchanged. A
    ``pattern`` list shorter than ``string`` is extended in place; a longer
    one keeps its extra trailing entries untouched.

    :type string: str
    :type pattern: list[int]
    :rtype: None
    """
    length = len(string)
    if length == 0:
        # Writing pattern[0] below would raise IndexError on an empty list.
        return
    if len(pattern) < length:
        # Make sure every pattern[j] assigned below exists.
        pattern.extend([0] * (length - len(pattern)))

    pattern[0] = -1
    j = 1   # index of the table entry being computed
    k = -1  # candidate border length for string[:j - 1]; -1 means "none"
    while j < length:
        if k == -1 or string[j - 1] == string[k]:
            # The border can be extended by one character (or restarts at 0).
            k += 1
            pattern[j] = k
            j += 1
        else:
            # Mismatch: retry with the next shorter border.
            k = pattern[k]
def Match(haystack, needle, pattern):
    """
    Search ``haystack`` for ``needle`` using a precomputed KMP fallback table.

    ``pattern`` is the table for ``needle`` with ``pattern[0] == -1``, as
    produced by GeneratePattern; it must have at least ``len(needle)``
    entries.

    Returns the index in ``haystack`` where the first occurrence of
    ``needle`` starts, 0 when ``needle`` is empty, or -1 when ``needle`` does
    not occur.

    :type haystack: str
    :type needle: str
    :type pattern: list[int]
    :rtype: int
    """
    hlen = len(haystack)
    nlen = len(needle)
    if nlen == 0:
        # An empty needle matches at the start; this also keeps the loop
        # below from indexing needle[0] and pattern[0].
        return 0

    j = 0  # position in haystack
    k = 0  # position in needle; starts at 0 (not -1) so haystack[0] is compared
    while j < hlen:
        if k == -1 or haystack[j] == needle[k]:
            # Either the characters agree, or k == -1 signals that needle[0]
            # itself mismatched: advance in haystack and continue from
            # needle[k + 1].
            j += 1
            k += 1
            if k == nlen:
                return j - k
        else:
            # Mismatch: shift the needle according to the fallback table
            # without moving backwards in haystack.
            k = pattern[k]
    return -1
if __name__ == '__main__':
    # Parenthesized form is valid in both Python 2 (print statement with a
    # parenthesized expression) and Python 3 (print function call).
    print(strStr("abaa", "aa"))