[python刷题模板] 矩阵快速幂(手写/numpy)

一、 算法&数据结构

1. 描述

矩阵快速幂是一种采用数学办法降低复杂度的操作。(其实就是把递推公式变成通项公式)
  • 如果一个递推公式可以写成诸如正方形矩阵的形式,且每个相邻递推公式系数不变,那么可以提出系数变成指数写法.
    • 记为: f[i] = m*f[i-1],且m不随i变化。
    • 那么可以推出f[n] = mn * f[0]。
  • 我们可以用快速幂计算m^n,这样本需要递推计算n次的操作,就降低到了log(n)次。

  • 若f[i]每一项只有一项,那么就是上述计算过程,非常好理解,一下就推出了通项公式。
  • 但若f[i][0/1/2…]这种二维矩阵怎么办呢?其实是一样的.
  • 可以将f[i]展开写成一列的矩阵,形如:
    • f[i][0] = c00 * f[i-1][0] + c01 * f[i-1][1] …
    • f[i][1] = c10 * f[i-1][0] + c11 * f[i-1][1] …
    • 即:
    • [ f [ i ] [ 0 ] f [ i ] [ 1 ] . . ] = [ c 00 c 01 . . c 10 c 11 . . . . ] ∗ [ f [ i − 1 ] [ 0 ] f [ i − 1 ] [ 1 ] . . ] \left[\begin {array}{c} f[i][0] \\ f[i][1] \\ .. \\ \end{array}\right] = \left[\begin {array}{c} c00 &c01 &.. \\ c10 &c11 &.. \\ .. \\ \end{array}\right] *\left[\begin {array}{c} f[i-1][0] \\ f[i-1][1] \\ .. \\ \end{array}\right] f[i][0]f[i][1].. = c00c10..c01c11.... f[i1][0]f[i1][1]..
    • 可推出:
    • [ f [ n ] [ 0 ] f [ n ] [ 1 ] . . ] = [ c 00 c 01 . . c 10 c 11 . . . . ] n ∗ [ f [ 0 ] [ 0 ] f [ 0 ] [ 1 ] . . ] \left[\begin {array}{c} f[n][0] \\ f[n][1] \\ .. \\ \end{array}\right] = \left[\begin {array}{c} c00 &c01 &.. \\ c10 &c11 &.. \\ .. \\ \end{array}\right]^n * \left[\begin {array}{c} f[0][0] \\ f[0][1] \\ .. \\ \end{array}\right] f[n][0]f[n][1].. = c00c10..c01c11.... n f[0][0]f[0][1]..
  • 于是,我们只需要知道f[0]和矩阵m,就相当于知道了通项公式。

  • 要注意的是,这需要保证一件事,就是上边提到的相邻递推公式系数不变
  • 换句话说,在递推过程中,每层的转移方法是固定的,不随着当前层的/下标等因素发生if等特殊改动(这通常是状态机)。
  • 这才能保证可以提出系数,而且需要矩阵是正方形,才能做平方操作。
    • 因此这种题可能首先要写普通dp,列出转移方程,再优化。
  • 注意,手写快速幂矩阵时,np.eye()(对角线是1)的矩阵,相当于数字里的1。
  • 另外,推完f[n],通常要考虑最终答案是什么,可能是fn[i],也可能是sum(fn),等。

2. 复杂度分析

  1. 把本来 O(n) 的递推操作,优化成了 O(logn) 的数学计算。

3. 常见应用

  1. 状态转移只依赖上层的线性DP(通常是状态机)。
  2. 状态转移只依赖上2/3…层的线性DP,采用错位写法。

4. 常用优化

  1. 注意很有可能需要特判n比较小的情况,可以避免很多wa,尤其是错位写法。
  2. 由于系数矩阵需要是正方形,要注意补0。
  3. fib这种计算,注意错位补0技巧。
  4. 利用numpy库省去手写矩阵乘法的过程.

import numpy as np
if n == 1:
    return 5
m = np.mat([
    [0, 1, 1, 0, 1],
    [1, 0, 1, 0, 0],
    [0, 1, 0, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 1, 1, 0]
])
f0 = np.mat([
    [1],
    [1],
    [1],
    [1],
    [1],
])
n -= 1
while n:
    if n & 1:
        f0 = m * f0 % MOD
    m = m * m % MOD
    n >>= 1
return int(f0.sum()) % MOD

二、 模板代码

1. 斐波那契数列(错位写矩阵,手写矩阵乘法)

例题: 509. 斐波那契数

def matrix_multiply(a, b, MOD=10 ** 9 + 7):
    m, n, p = len(a), len(a[0]), len(b[0])
    ans = [[0] * p for _ in range(m)]
    for i in range(m):
        for j in range(n):
            for k in range(p):
                ans[i][k] = (ans[i][k] + a[i][j] * b[j][k])
    return ans


def matrix_pow_mod(a, b, MOD=10 ** 9 + 7):
    n = len(a)
    ans = [[0] * n for _ in range(n)]
    for i in range(n):
        ans[i][i] = 1
    while b:
        if b & 1:
            ans = matrix_multiply(ans, a, MOD)
        a = matrix_multiply(a, a, MOD)
        b >>= 1
    return ans


class Solution:
    def fib(self, n: int) -> int:
        if n == 0:
            return 0
        m = [[1, 1], [1, 0]]
        return matrix_pow_mod(m, n - 1)[0][0]

2. 1137. 第 N 个泰波那契数(错位写矩阵,手写矩阵乘法)

链接: 1137. 第 N 个泰波那契数

def matrix_multiply(a, b, MOD=10 ** 9 + 7):
    m, n, p = len(a), len(a[0]), len(b[0])
    ans = [[0] * p for _ in range(m)]
    for i in range(m):
        for j in range(n):
            for k in range(p):
                ans[i][k] = (ans[i][k] + a[i][j] * b[j][k])
    return ans


def matrix_pow_mod(a, b, MOD=10 ** 9 + 7):
    n = len(a)
    ans = [[0] * n for _ in range(n)]
    for i in range(n):
        ans[i][i] = 1
    while b:
        if b & 1:
            ans = matrix_multiply(ans, a, MOD)
        a = matrix_multiply(a, a, MOD)
        b >>= 1
    return ans


class Solution:
    def fib(self, n: int) -> int:
        if n == 0:
            return 0
        m = [[1, 1], [1, 0]]
        return matrix_pow_mod(m, n - 1)[0][0]

3. 1220. 统计元音字母序列的数目(状态机DP,用numpy)

链接: 1220. 统计元音字母序列的数目

MOD = 10**9+7

import numpy  as np
class Solution:
    def countVowelPermutation(self, n: int) -> int:
        """
        定义f[i][0~4]表示长为i+1的字符串,最后结尾是aeiou的种类数
        显然f[0][0:5] = [1,1,1,1,1]
        下边g = f[i-1]
        f[i][0] = 0g[0] + 1g[1] + 1g[2] + 0g[3] + 1g[4]
        f[i][1] = 1g[0] + 0g[1] + 1g[2] + 0g[3] + 0g[4]
        f[i][2] = 0g[0] + 1g[1] + 0g[2] + 1g[3] + 0g[4]
        f[i][3] = 0g[0] + 0g[1] + 1g[2] + 0g[3] + 0g[4]
        f[i][4] = 0g[0] + 0g[1] + 1g[2] + 1g[3] + 0g[4]
        """
        if n == 1:
            return 5
        m = np.mat([
            [0,1,1,0,1],
            [1,0,1,0,0],
            [0,1,0,1,0],
            [0,0,1,0,0],
            [0,0,1,1,0]
        ])
        f0 = np.mat([
            [1],
            [1],
            [1],
            [1],
            [1],
        ])
        n -= 1
        while n:
            if n &1:
                f0 = m*f0 %MOD 
            m = m*m%MOD 
            n >>= 1
        return int(f0.sum()) %MOD

4. 552. 学生出勤记录 II(2维状态机DP展开成1维,用numpy)

链接: 552. 学生出勤记录 II

  • 这题状态维度多一层,还好是2*3,可以直接展开。
MOD = 10 ** 9 + 7
import numpy as np
class Solution:
    def checkRecord(self, n: int) -> int:
        '''
        f[i][0/1][0/1/2]表示i天,A=0,1,最近连续L为0/1/2时的情况
        '''
        # f = [[0]*3 for _ in range(2)]
        # f[0][0] = 1
        # for _ in range(n):
        #     g = [[0]*3 for _ in range(2)]
        #     g[0][0] = f[0][0] + f[0][1] + f[0][2]
        #     g[0][1] = f[0][0] 
        #     g[0][2] = f[0][1]
        #     g[1][0] = f[0][0] + f[0][1] + f[0][2] + f[1][0] + f[1][1] + f[1][2]
        #     g[1][1] = f[1][0]
        #     g[1][2] = f[1][1]
        #     for i in range(2):
        #         for j in range(3):
        #             g[i][j] %= MOD 
        #     f = g 
        # return sum(sum(row) for row in f) %MOD
        '''
        0 0:0
        0 1:1
        0 2:2
        1 0:3
        1 1:4
        1 2:5
        f[i][0] = [1, 1, 1, 0, 0, 0]
        f[i][1] = [1, 0, 0, 0, 0, 0]
        f[i][2] = [0, 1, 0, 0, 0, 0]
        f[i][3] = [1, 1, 1, 1, 1, 1]
        f[i][4] = [0, 0, 0, 1, 0, 0]
        f[i][5] = [0, 0, 0, 0, 1, 0]
        f[n] = m^n * [[1],[0],[0],[0],[0],[0]]
        '''
        f0 = np.mat([[1],[0],[0],[0],[0],[0]])
        m = np.mat([
            [1, 1, 1, 0, 0, 0],
            [1, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 1, 0],
        ])
        while n:
            if n & 1:
                f0 = m * f0 %MOD 
            m = m * m %MOD 
            n >>= 1
        return int(f0.sum()%MOD)       

5. 2851. 字符串转换(KMP+矩阵快速幂)

链接: 2851. 字符串转换

  • 难点在于想到kmp,以及状态转移。
MOD = 10 ** 9 + 7
class Kmp:
    """kmp算法,计算前缀函数pi,根据pi转移,复杂度O(m+n)"""

    def __init__(self, t):
        """传入模式串,计算前缀函数"""
        self.t = t
        n = len(t)
        self.pi = pi = [0] * n
        j = 0
        for i in range(1, n):
            while j and t[i] != t[j]:
                j = pi[j - 1]  # 失配后缩短期望匹配长度
            if t[i] == t[j]:
                j += 1  # 多配一个
            pi[i] = j

    def find_all_yield(self, s):
        """查找t在s中的所有位置,注意可能为空"""
        n, t, pi, j = len(self.t), self.t, self.pi, 0        
        for i, v in enumerate(s):
            while j and v != t[j]:
                j = pi[j - 1]
            if v == t[j]:
                j += 1
            if j == n:
                yield i - j + 1
                j = pi[j - 1]

    def find_one(self, s):
        """查找t在s中的第一个位置,如果不存在就返回-1"""
        for ans in self.find_all_yield(s):
            return ans
        return -1
def matrix_multiply(a, b, MOD=10**9+7):
    m, n, p = len(a), len(a[0]), len(b[0])
    ans = [[0]*p for _ in range(m)]
    for i in range(m):
        for j in range(n):
            for k in range(p):
                ans[i][k] = (ans[i][k]+a[i][j] * b[j][k]) %MOD
    return ans 
def matrix_pow_mod(a, b, MOD=10**9+7):
    n = len(a)
    ans = [[0]*n for _ in range(n)]
    for i in range(n):
        ans[i][i] = 1 
    while b:
        if b & 1:
            ans = matrix_multiply(ans, a, MOD)
        a = matrix_multiply(a, a, MOD)
        b >>= 1
    return ans


class Solution:
    def numberOfWays(self, s: str, t: str, k: int) -> int:
        if t not in s+s:
            return 0
        n = len(s)
        c = len(list(Kmp(t).find_all_yield(s+s[:-1])))        
        m = [
            [c-1, c],
            [n-c,n-1-c]
        ]
        m = matrix_pow_mod(m, k)
        return m[0][s != t]    

三、其他

  1. 一定别忘了取模。
  2. 通常小数据特判。

四、更多例题

五、参考链接

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值