[python刷题模板] 前缀函数/next数组/kmp算法

一、 算法&数据结构

1. 描述

前缀函数和next数组基本上是一个东西,但储存的内容不同。
他们是kmp算法的基础。但真的不太好理解,以及不好写,背不过。
前缀函数π(i)可以在O(n)的时间计算出来数组内每个前缀的前缀函数。
  • 参考 oiwiki前缀函数与 KMP 算法
  • kmp还可以结合字典树搞ac自动机,待施工。
  • 前缀函数π[i]代表的前缀s[:i+1]和后缀s[-i:]相同的情况下,是前缀长度。
    • 简单来说 pi[i] 就是,子串 s[0… i] 最长的相等的真前缀与真后缀的长度。
  • next数组是指模式串在i位置匹配失败后,应该向前跳到哪个位置开始继续匹配。

2. 复杂度分析

  1. 预处理O(n)
  2. 查询O(n)

3. 常见应用

  1. 字符串查询。

4. 常用优化

  1. 从意义上来说,前缀函数值得是前后缀相同的长度;next数组是匹配失败后模式串指针j要去的位置。
  • 因此kmp搜索用next数组写法简单点(参考模板代码3);但找前后缀用前缀函数更直观(模板代码1)。

二、 模板代码

1. 裸前缀函数

例题: 4808. 构造字符串
这题暴力能过,但还是前缀函数nb。

# Problem: 构造字符串
# Contest: AcWing
# URL: https://www.acwing.com/problem/content/4811/
# Memory Limit: 256 MB
# Time Limit: 1000 ms

import sys
import bisect
import random
import io, os
from bisect import *
from collections import *
from contextlib import redirect_stdout
from itertools import *
from array import *
from functools import lru_cache
from types import GeneratorType
from heapq import *
from math import sqrt, gcd, inf

if sys.version >= '3.8':  # ACW没有comb
    from math import comb

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda *x: sys.stderr.write(f'{str(x)}\n')

MOD = 10 ** 9 + 7

def prefix_function(s):
    """计算s的前缀函数"""
    n = len(s)
    pi = [0] * n
    for i in range(1, n):
        j = pi[i - 1]
        while j > 0 and s[i] != s[j]:
            j = pi[j - 1]
        if s[i] == s[j]:
            j += 1
        pi[i] = j
    return pi
#       ms
def solve():
    n, k = RI()
    t, = RS()
    mx = prefix_function(t)[-1]

    if mx == 0:
        return print(t * k)
    suf = t[mx:]
    print(t + suf * (k - 1))


if __name__ == '__main__':
    solve()

2. 树上kmp

链接: 1367. 二叉树中的链表

试了下树上kmp是负优化,但可能是数据问题。

class Solution:
    def isSubPath(self, head: ListNode, root: TreeNode) -> bool:
        path = []
        while head:
            path.append(head.val)
            head = head.next
        n = len(path)
        def get_next(p):
            n = len(p)
            nxt = [0]*n
            nxt[0] = -1
            j,k=0,-1
            while j < n-1:
                if k == -1 or p[j] == p[k]:
                    j+=1
                    k+=1
                    if p[j] == p[k]:
                        nxt[j] = nxt[k]
                    else:
                        nxt[j] = k 
                else:
                    k = nxt[k]
           
            return nxt
        nxt = get_next(path)
        # print(nxt)
        
        def dfs_kmp(tree, j):
            if j == n:
                return True
            if not tree:
                return False
            if j == -1 or tree.val == path[j]:
                return dfs_kmp(tree.left,j+1) or dfs_kmp(tree.right,j+1)
            else:
                return dfs_kmp(tree,nxt[j]) 

3. 裸kmp

链接: 28. 找出字符串中第一个匹配项的下标

class Kmp:
    """kmp算法,计算前缀函数pi,根据pi转移,复杂度O(m+n)"""

    def __init__(self, t):
        """传入模式串,计算前缀函数"""
        self.t = t
        n = len(t)
        self.pi = pi = [0] * n
        j = 0
        for i in range(1, n):
            while j and t[i] != t[j]:
                j = pi[j - 1]  # 失配后缩短期望匹配长度
            if t[i] == t[j]:
                j += 1  # 多配一个
            pi[i] = j

    def find_all_yield(self, s):
        """查找t在s中的所有位置,注意可能为空"""
        n, t, pi, j = len(self.t), self.t, self.pi, 0
        for i, v in enumerate(s):
            while j and v != t[j]:
                j = pi[j - 1]
            if v == t[j]:
                j += 1
            if j == n:
                yield i - j + 1
                j = pi[j - 1]

    def find_one(self, s):
        """查找t在s中的第一个位置,如果不存在就返回-1"""
        for ans in self.find_all_yield(s):
            return ans
        return -1

class Solution:
    def strStr(self, haystack: str, needle: str) -> int:
        m,n = len(haystack),len(needle)
        
        # def get_next(p):
        #     n = len(p)
        #     nxt = [-1] * n
        #     j, k = 0, -1
        #     while j < n - 1:
        #         if k == -1 or p[j] == p[k]:
        #             j += 1
        #             k += 1
        #             if p[j] == p[k]:
        #                 nxt[j] = nxt[k]
        #             else:
        #                 nxt[j] = k
        #         else:
        #             k = nxt[k]

        #     return nxt
            
        # nxt = get_next(needle)
        # print(nxt)

        # i = j = 0        
        # while i < m and j < n:
        #     if j == -1 or haystack[i] == needle[j]:
        #         i += 1
        #         j += 1
        #     else:
        #         j = nxt[j]
        # if j == n:
        #     return i - j 
        # return -1

        
        def prefix_function(s):
            """计算s的前缀函数"""
            n = len(s)
            pi = [0] * n
            for i in range(1, n):
                j = pi[i - 1]
                while j > 0 and s[i] != s[j]:
                    j = pi[j - 1]
                if s[i] == s[j]:
                    j += 1
                pi[i] = j
            return pi
            
        pi = prefix_function(needle)
        print(pi)

        i ,j = 0,0        
        while i < m and j < n:
            while  j > 0 and haystack[i] != needle[j]:
                j = pi[j-1]
            if haystack[i] == needle[j]:               
                j += 1
            if j == n:
                return i - j + 1
            
            i += 1
        return -1

4. KMP+矩阵快速幂

链接: 2851. 字符串转换

矩阵快速幂,要求矩阵是正方形,可以把线性dp的O(n)优化成O(lgn)
https://leetcode.cn/problems/string-transformation/
"""
MOD = 10 ** 9 + 7


class Kmp:
    """kmp算法,计算前缀函数pi,根据pi转移,复杂度O(m+n)"""

    def __init__(self, t):
        """传入模式串,计算前缀函数"""
        self.t = t
        n = len(t)
        self.pi = pi = [0] * n
        j = 0
        for i in range(1, n):
            while j and t[i] != t[j]:
                j = pi[j - 1]  # 失配后缩短期望匹配长度
            if t[i] == t[j]:
                j += 1  # 多配一个
            pi[i] = j

    def find_all_yield(self, s):
        """查找t在s中的所有位置,注意可能为空"""
        n, t, pi, j = len(self.t), self.t, self.pi, 0
        for i, v in enumerate(s):
            while j and v != t[j]:
                j = pi[j - 1]
            if v == t[j]:
                j += 1
            if j == n:
                yield i - j + 1
                j = pi[j - 1]

    def find_one(self, s):
        """查找t在s中的第一个位置,如果不存在就返回-1"""
        for ans in self.find_all_yield(s):
            return ans
        return -1


def matrix_multiply(a, b, MOD=10 ** 9 + 7):
    m, n, p = len(a), len(a[0]), len(b[0])
    ans = [[0] * p for _ in range(m)]
    for i in range(m):
        for j in range(n):
            for k in range(p):
                ans[i][k] = (ans[i][k] + a[i][j] * b[j][k]) % MOD
    return ans


def matrix_pow_mod(a, b, MOD=10 ** 9 + 7):
    n = len(a)
    ans = [[0] * n for _ in range(n)]
    for i in range(n):
        ans[i][i] = 1
    while b:
        if b & 1:
            ans = matrix_multiply(ans, a, MOD)
        a = matrix_multiply(a, a, MOD)
        b >>= 1
    return ans


class Solution:
    def numberOfWays(self, s: str, t: str, k: int) -> int:
        if t not in s + s:
            return 0
        n = len(s)
        c = len(list(Kmp(t).find_all_yield(s + s[:-1])))
        m = [
            [c - 1, c],
            [n - c, n - 1 - c]
        ]
        m = matrix_pow_mod(m, k)
        return m[0][s != t]

三、其他

四、更多例题

五、参考链接

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值