kmp算法,看了知乎的回答,前面都看懂了,关键如何获取next数组没理解太透彻。顺着前面的思路,自己想明白了。
https://www.zhihu.com/question/21923021/answer/281346746
首先明白部分匹配表(pmt):PMT中的值是字符串的前缀集合与后缀集合的交集中最长元素的长度。
上一张图:
需要明白pmt和next的关系,pmt数组整体向右移动一格,然后截掉最后一位,第一位用0补全就会得到next数组。所以获取next其实就是获取pmt。核心就变成了如何获取目标字符串的pmt,至于获取pmt的函数get_part_match_table,一两句话说不清,思路就是两个目标字符串上下做滑动窗口的操作,逐步获取数组的值,相当于又一次kmp操作
完成代码如下,已用leetcode这道题验证:
https://leetcode-cn.com/problems/implement-strstr/
# -*- coding: utf-8 -*-
def get_pmt(target_str):
"""
获取部分匹配表
例如:
target_str = 'aaabaa'
part_table = [0,1,2,0,1,2]
:param target_str: 要匹配的字符串
:return:
"""
i = 1
j = 0
part_table = [0] * len(target_str)
while i < len(target_str):
if target_str[i] == target_str[j]:
part_table[i] = j + 1
j += 1
i += 1
else:
if j == 0:
i += 1
continue
j = part_table[j - 1]
# if j == 0:
# i += 1
# print(part_table)
return part_table
def get_next(target_str):
"""
获取next表
:param target_str: 要匹配的字符串
:return:
"""
part_table = get_pmt(target_str)
next = [0] + part_table[:-1]
return next
def kmp(source_str, target_str):
i = j = 0
next = get_next(target_str)
# i + len(target_str) - j <= len(source_str)这句代码判断断滑动窗口是否越界
# 下面这种情况,i和j都不越界,但是滑动窗口以及越界了
# aabaabaaa
# aabaaa
while i < len(source_str) and j < len(target_str) and i + len(target_str) - j <= len(source_str):
if source_str[i] == target_str[j]:
i += 1
j += 1
else:
if j == 0:
i += 1
j = next[j]
if j == len(target_str):
return i - j
return -1
if __name__ == '__main__':
source_str = "aabaabaaa"
target_str = 'aabaaa'
position = kmp(source_str, target_str)
print(position)