[python刷题模板] 前缀函数/next数组/kmp算法
一、 算法&数据结构
1. 描述
前缀函数和next数组基本上是一个东西,但储存的内容不同。
他们是kmp算法的基础。但真的不太好理解,以及不好写,背不过。
前缀函数可以在O(n)的时间内计算出字符串每个前缀对应的前缀函数值π(i)。
- 参考 oiwiki前缀函数与 KMP 算法
- kmp还可以结合字典树搞ac自动机,待施工。
- 前缀函数π[i]表示:在子串s[:i+1]中,真前缀与真后缀相同时,相同部分的最大长度。
- 简单来说 pi[i] 就是,子串 s[0… i] 最长的相等的真前缀与真后缀的长度。
- next数组是指模式串在i位置匹配失败后,应该向前跳到哪个位置开始继续匹配。
2. 复杂度分析
- 预处理O(n)
- 查询O(n)
3. 常见应用
- 字符串查询。
4. 常用优化
- 从意义上来说,前缀函数指的是前后缀相同的长度;next数组是匹配失败后模式串指针j要去的位置。
- 因此kmp搜索用next数组写法简单点(参考模板代码3);但找前后缀用前缀函数更直观(模板代码1)。
二、 模板代码
1. 裸前缀函数
例题: 4808. 构造字符串
这题暴力能过,但还是前缀函数nb。
# Problem: 构造字符串
# Contest: AcWing
# URL: https://www.acwing.com/problem/content/4811/
# Memory Limit: 256 MB
# Time Limit: 1000 ms
import sys
import bisect
import random
import io, os
from bisect import *
from collections import *
from contextlib import redirect_stdout
from itertools import *
from array import *
from functools import lru_cache
from types import GeneratorType
from heapq import *
from math import sqrt, gcd, inf
# ACW (AcWing) has no math.comb, so only import it where it exists.
# Compare the version_info tuple rather than the sys.version string: string
# comparison is lexicographic, so '3.10...' would sort below '3.8' and the
# import would be skipped on Python 3.10+.
if sys.version_info >= (3, 8):
    from math import comb
def RI():
    """Read one line from stdin and return a lazy iterator over its integers."""
    return map(int, sys.stdin.buffer.readline().split())


def RS():
    """Read one line from stdin and return a lazy iterator over its tokens as str."""
    return map(bytes.decode, sys.stdin.buffer.readline().strip().split())


def RILST():
    """Read one line from stdin and return its integers as a list."""
    return list(RI())


def DEBUG(*x):
    """Write the tuple of arguments to stderr, followed by a newline."""
    return sys.stderr.write(f'{str(x)}\n')


MOD = 10 ** 9 + 7
def prefix_function(s):
    """Compute the prefix function of s.

    Returns a list whose entry at index i is the length of the longest proper
    prefix of s[:i + 1] that is also a suffix of s[:i + 1].
    """
    size = len(s)
    table = [0] * size
    matched = 0  # always equals table[pos - 1] at the top of each iteration
    for pos in range(1, size):
        cur = s[pos]
        # Fall back to shorter borders until one can be extended by cur.
        while matched and cur != s[matched]:
            matched = table[matched - 1]
        if cur == s[matched]:
            matched += 1
        table[pos] = matched
    return table
# ms
def solve():
    """Read `n k` and the string t, then print t followed by k - 1 reduced copies.

    The last prefix-function value of t is the length of its longest border
    (a proper prefix that is also a suffix). Each extra copy only appends the
    part of t after that border; with no border the output is t repeated k
    times.
    """
    _, k = RI()
    t, = RS()
    border = prefix_function(t)[-1]
    if border:
        tail = t[border:]
        print(t + tail * (k - 1))
    else:
        print(t * k)


if __name__ == '__main__':
    solve()
2. 树上kmp
链接: 1367. 二叉树中的链表
试了下树上kmp是负优化,但可能是数据问题。
class Solution:
    def isSubPath(self, head: ListNode, root: TreeNode) -> bool:
        """Return True if the linked list appears as a downward path in the tree.

        The list values form a KMP pattern; every root-to-descendant path of
        the tree is treated as a text, and the current match length is carried
        down the recursion.

        Args:
            head: First node of the linked list (nodes expose `val`, `next`).
            root: Root of the binary tree (nodes expose `val`, `left`, `right`).

        Returns:
            True if some downward path matches the whole list, else False.
            An empty list matches trivially.
        """
        # Flatten the linked list into the pattern array.
        path = []
        while head:
            path.append(head.val)
            head = head.next
        n = len(path)
        if n == 0:
            # get_next writes nxt[0], which would fail for an empty pattern.
            return True

        def get_next(p):
            """Build the optimised next array (nxt[0] == -1) for pattern p."""
            n = len(p)
            nxt = [0] * n
            nxt[0] = -1
            j, k = 0, -1
            while j < n - 1:
                if k == -1 or p[j] == p[k]:
                    j += 1
                    k += 1
                    if p[j] == p[k]:
                        # Jumping to k would fail on the same value again,
                        # so reuse k's own fallback directly.
                        nxt[j] = nxt[k]
                    else:
                        nxt[j] = k
                else:
                    k = nxt[k]
            return nxt

        nxt = get_next(path)

        def dfs_kmp(tree, j):
            """Match path[j:] starting at `tree`; j == -1 means skip this node."""
            if j == n:
                return True
            if not tree:
                return False
            if j == -1 or tree.val == path[j]:
                return dfs_kmp(tree.left, j + 1) or dfs_kmp(tree.right, j + 1)
            # Mismatch: retry the same node with the fallback pattern index.
            return dfs_kmp(tree, nxt[j])

        # The original never invoked the search, so the method returned None.
        return dfs_kmp(root, 0)
3. 裸kmp
class Kmp:
    """KMP matcher: computes the prefix function pi of the pattern and
    transitions by pi while scanning the text; complexity O(m + n)."""

    def __init__(self, t):
        """Store the pattern t and compute its prefix function."""
        self.t = t
        n = len(t)
        self.pi = pi = [0] * n
        j = 0
        for i in range(1, n):
            while j and t[i] != t[j]:
                j = pi[j - 1]  # mismatch: shorten the expected match length
            if t[i] == t[j]:
                j += 1  # one more character matched
            pi[i] = j

    def find_all_yield(self, s):
        """Yield every start index at which t occurs in s (possibly none).

        Overlapping occurrences are reported. An empty pattern matches at
        every position 0..len(s), mirroring str.find("") semantics; the
        original raised IndexError in that case.
        """
        n, t, pi, j = len(self.t), self.t, self.pi, 0
        if n == 0:
            # Count while iterating so that s may be any iterable, as before.
            seen = 0
            for seen, _ in enumerate(s, 1):
                yield seen - 1
            yield seen
            return
        for i, v in enumerate(s):
            while j and v != t[j]:
                j = pi[j - 1]
            if v == t[j]:
                j += 1
            if j == n:
                yield i - j + 1
                # Continue from the longest border to catch overlaps.
                j = pi[j - 1]

    def find_one(self, s):
        """Return the first index of t in s, or -1 if t does not occur."""
        for ans in self.find_all_yield(s):
            return ans
        return -1
class Solution:
    def strStr(self, haystack: str, needle: str) -> int:
        """Return the index of the first occurrence of needle in haystack.

        Uses KMP driven by the prefix function of needle.

        Returns:
            The 0-based start index of the first match, -1 if there is none,
            and 0 for an empty needle (same convention as str.find).
        """
        m, n = len(haystack), len(needle)
        if n == 0:
            # The matching loop below never runs for an empty needle and
            # would fall through to -1.
            return 0

        # Alternative implementation based on the next array, kept for
        # reference:
        # def get_next(p):
        #     n = len(p)
        #     nxt = [-1] * n
        #     j, k = 0, -1
        #     while j < n - 1:
        #         if k == -1 or p[j] == p[k]:
        #             j += 1
        #             k += 1
        #             if p[j] == p[k]:
        #                 nxt[j] = nxt[k]
        #             else:
        #                 nxt[j] = k
        #         else:
        #             k = nxt[k]
        #     return nxt
        # nxt = get_next(needle)
        # print(nxt)
        # i = j = 0
        # while i < m and j < n:
        #     if j == -1 or haystack[i] == needle[j]:
        #         i += 1
        #         j += 1
        #     else:
        #         j = nxt[j]
        # if j == n:
        #     return i - j
        # return -1

        def prefix_function(s):
            """Compute the prefix function of s."""
            n = len(s)
            pi = [0] * n
            for i in range(1, n):
                j = pi[i - 1]
                while j > 0 and s[i] != s[j]:
                    j = pi[j - 1]
                if s[i] == s[j]:
                    j += 1
                pi[i] = j
            return pi

        pi = prefix_function(needle)
        j = 0  # number of needle characters currently matched
        for i in range(m):
            while j > 0 and haystack[i] != needle[j]:
                j = pi[j - 1]
            if haystack[i] == needle[j]:
                j += 1
            if j == n:
                return i - j + 1
        return -1
4. KMP+矩阵快速幂
链接: 2851. 字符串转换
矩阵快速幂,要求矩阵是正方形,可以把线性dp的O(n)优化成O(lgn)
https://leetcode.cn/problems/string-transformation/
"""
# Modulus 10^9 + 7; matrix_multiply and matrix_pow_mod default to the same value.
MOD = 10 ** 9 + 7
class Kmp:
    """KMP matcher: computes the prefix function pi of the pattern and
    transitions by pi while scanning the text; complexity O(m + n)."""

    def __init__(self, t):
        """Store the pattern t and compute its prefix function."""
        self.t = t
        n = len(t)
        self.pi = pi = [0] * n
        j = 0
        for i in range(1, n):
            while j and t[i] != t[j]:
                j = pi[j - 1]  # mismatch: shorten the expected match length
            if t[i] == t[j]:
                j += 1  # one more character matched
            pi[i] = j

    def find_all_yield(self, s):
        """Yield every start index at which t occurs in s (possibly none).

        Overlapping occurrences are reported. An empty pattern matches at
        every position 0..len(s), mirroring str.find("") semantics; the
        original raised IndexError in that case.
        """
        n, t, pi, j = len(self.t), self.t, self.pi, 0
        if n == 0:
            # Count while iterating so that s may be any iterable, as before.
            seen = 0
            for seen, _ in enumerate(s, 1):
                yield seen - 1
            yield seen
            return
        for i, v in enumerate(s):
            while j and v != t[j]:
                j = pi[j - 1]
            if v == t[j]:
                j += 1
            if j == n:
                yield i - j + 1
                # Continue from the longest border to catch overlaps.
                j = pi[j - 1]

    def find_one(self, s):
        """Return the first index of t in s, or -1 if t does not occur."""
        for ans in self.find_all_yield(s):
            return ans
        return -1
def matrix_multiply(a, b, MOD=10 ** 9 + 7):
    """Return the matrix product a * b with every entry reduced modulo MOD.

    a is treated as len(a) x len(a[0]) and b as len(a[0]) x len(b[0]);
    both are lists of row lists.
    """
    inner, cols = len(a[0]), len(b[0])
    product = []
    for row in a:
        out = [0] * cols
        for j in range(inner):
            coeff = row[j]
            b_row = b[j]
            for k in range(cols):
                out[k] = (out[k] + coeff * b_row[k]) % MOD
        product.append(out)
    return product
def matrix_pow_mod(a, b, MOD=10 ** 9 + 7):
    """Return the square matrix a raised to the power b, entries modulo MOD.

    Uses binary exponentiation, so it performs O(log b) calls to
    matrix_multiply.

    Args:
        a: Square matrix as a list of row lists.
        b: Non-negative integer exponent.
        MOD: Modulus applied to every entry.

    Raises:
        ValueError: If b is negative. The original looped forever in that
            case, because right-shifting a negative int never reaches 0.
    """
    if b < 0:
        raise ValueError("matrix_pow_mod requires a non-negative exponent")
    n = len(a)
    ans = [[0] * n for _ in range(n)]
    for i in range(n):
        # Reduce the identity as well so that MOD == 1 yields all zeros.
        ans[i][i] = 1 % MOD
    while b:
        if b & 1:
            ans = matrix_multiply(ans, a, MOD)
        b >>= 1
        if b:
            # Square only while more bits remain; the final squaring
            # would be discarded anyway.
            a = matrix_multiply(a, a, MOD)
    return ans
class Solution:
    def numberOfWays(self, s: str, t: str, k: int) -> int:
        """Count the ways to reach t from s in k steps, via a 2x2 matrix power.

        Returns 0 immediately when t does not occur in s + s. Otherwise it
        counts the occurrences of t in s + s[:-1] (when len(t) == len(s),
        each occurrence corresponds to one rotation of s equal to t), builds
        a two-state transition matrix from that count and len(s), raises it
        to the k-th power, and reads the entry in row 0 selected by whether
        s differs from t.
        """
        doubled = s + s
        if t not in doubled:
            return 0
        length = len(s)
        hits = sum(1 for _ in Kmp(t).find_all_yield(doubled[:-1]))
        transition = [
            [hits - 1, hits],
            [length - hits, length - 1 - hits],
        ]
        powered = matrix_pow_mod(transition, k)
        # Column 0 when s already equals t, column 1 otherwise.
        column = 1 if s != t else 0
        return powered[0][column]