[python刷题模板] 矩阵快速幂 (手写/numpy
一、 算法&数据结构
1. 描述
矩阵快速幂是一种采用数学办法降低复杂度的操作。(其实就是把递推公式变成通项公式)
- 如果一个递推公式可以写成诸如正方形矩阵的形式,且每个相邻递推公式系数不变,那么可以提出系数变成指数写法.
- 记为: f[i] = m*f[i-1],且m不随i变化。
- 那么可以推出f[n] = mn * f[0]。
- 我们可以用快速幂计算
m^n
,这样本需要递推计算n次的操作,就降低到了log(n)次。
- 若f[i]每一项只有一项,那么就是上述计算过程,非常好理解,一下就推出了通项公式。
- 但若f[i][0/1/2…]这种二维矩阵怎么办呢?其实是一样的.
- 可以将f[i]展开写成一列的矩阵,形如:
- f[i][0] = c00 * f[i-1][0] + c01 * f[i-1][1] …
- f[i][1] = c10 * f[i-1][0] + c11 * f[i-1][1] …
- …
- 即:
- [ f [ i ] [ 0 ] f [ i ] [ 1 ] . . ] = [ c 00 c 01 . . c 10 c 11 . . . . ] ∗ [ f [ i − 1 ] [ 0 ] f [ i − 1 ] [ 1 ] . . ] \left[\begin {array}{c} f[i][0] \\ f[i][1] \\ .. \\ \end{array}\right] = \left[\begin {array}{c} c00 &c01 &.. \\ c10 &c11 &.. \\ .. \\ \end{array}\right] *\left[\begin {array}{c} f[i-1][0] \\ f[i-1][1] \\ .. \\ \end{array}\right] f[i][0]f[i][1].. = c00c10..c01c11.... ∗ f[i−1][0]f[i−1][1]..
- 可推出:
- [ f [ n ] [ 0 ] f [ n ] [ 1 ] . . ] = [ c 00 c 01 . . c 10 c 11 . . . . ] n ∗ [ f [ 0 ] [ 0 ] f [ 0 ] [ 1 ] . . ] \left[\begin {array}{c} f[n][0] \\ f[n][1] \\ .. \\ \end{array}\right] = \left[\begin {array}{c} c00 &c01 &.. \\ c10 &c11 &.. \\ .. \\ \end{array}\right]^n * \left[\begin {array}{c} f[0][0] \\ f[0][1] \\ .. \\ \end{array}\right] f[n][0]f[n][1].. = c00c10..c01c11.... n∗ f[0][0]f[0][1]..
- 于是,我们只需要知道f[0]和矩阵m,就相当于知道了通项公式。
- 要注意的是,这需要
保证
一件事,就是上边提到的相邻递推公式系数不变。 - 换句话说,在递推过程中,每层的转移方法是固定的,不随着当前层的
值
/下标
等因素发生if
等特殊改动(这通常是状态机)。 - 这才能保证可以提出系数,而且需要矩阵是正方形,才能做平方操作。
- 因此这种题可能首先要写普通dp,列出转移方程,再优化。
- 注意,手写快速幂矩阵时,np.eye()(对角线是1)的矩阵,相当于数字里的1。
- 另外,推完f[n],通常要考虑最终答案是什么,可能是fn[i],也可能是sum(fn),等。
2. 复杂度分析
- 把本来 O(n) 的递推操作,优化成了 O(logn) 的数学计算。
3. 常见应用
- 状态转移只依赖上层的线性DP(通常是状态机)。
- 状态转移只依赖上2/3…层的线性DP,采用错位写法。
4. 常用优化
import numpy as np
if n == 1:
return 5
m = np.mat([
[0, 1, 1, 0, 1],
[1, 0, 1, 0, 0],
[0, 1, 0, 1, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 1, 0]
])
f0 = np.mat([
[1],
[1],
[1],
[1],
[1],
])
n -= 1
while n:
if n & 1:
f0 = m * f0 % MOD
m = m * m % MOD
n >>= 1
return int(f0.sum()) % MOD
二、 模板代码
1. 斐波那契数列(错位写矩阵,手写矩阵乘法)
例题: 509. 斐波那契数
def matrix_multiply(a, b, MOD=10 ** 9 + 7):
m, n, p = len(a), len(a[0]), len(b[0])
ans = [[0] * p for _ in range(m)]
for i in range(m):
for j in range(n):
for k in range(p):
ans[i][k] = (ans[i][k] + a[i][j] * b[j][k])
return ans
def matrix_pow_mod(a, b, MOD=10 ** 9 + 7):
n = len(a)
ans = [[0] * n for _ in range(n)]
for i in range(n):
ans[i][i] = 1
while b:
if b & 1:
ans = matrix_multiply(ans, a, MOD)
a = matrix_multiply(a, a, MOD)
b >>= 1
return ans
class Solution:
def fib(self, n: int) -> int:
if n == 0:
return 0
m = [[1, 1], [1, 0]]
return matrix_pow_mod(m, n - 1)[0][0]
2. 1137. 第 N 个泰波那契数(错位写矩阵,手写矩阵乘法)
链接: 1137. 第 N 个泰波那契数
def matrix_multiply(a, b, MOD=10 ** 9 + 7):
m, n, p = len(a), len(a[0]), len(b[0])
ans = [[0] * p for _ in range(m)]
for i in range(m):
for j in range(n):
for k in range(p):
ans[i][k] = (ans[i][k] + a[i][j] * b[j][k])
return ans
def matrix_pow_mod(a, b, MOD=10 ** 9 + 7):
n = len(a)
ans = [[0] * n for _ in range(n)]
for i in range(n):
ans[i][i] = 1
while b:
if b & 1:
ans = matrix_multiply(ans, a, MOD)
a = matrix_multiply(a, a, MOD)
b >>= 1
return ans
class Solution:
def fib(self, n: int) -> int:
if n == 0:
return 0
m = [[1, 1], [1, 0]]
return matrix_pow_mod(m, n - 1)[0][0]
3. 1220. 统计元音字母序列的数目(状态机DP,用numpy)
MOD = 10**9+7
import numpy as np
class Solution:
def countVowelPermutation(self, n: int) -> int:
"""
定义f[i][0~4]表示长为i+1的字符串,最后结尾是aeiou的种类数
显然f[0][0:5] = [1,1,1,1,1]
下边g = f[i-1]
f[i][0] = 0g[0] + 1g[1] + 1g[2] + 0g[3] + 1g[4]
f[i][1] = 1g[0] + 0g[1] + 1g[2] + 0g[3] + 0g[4]
f[i][2] = 0g[0] + 1g[1] + 0g[2] + 1g[3] + 0g[4]
f[i][3] = 0g[0] + 0g[1] + 1g[2] + 0g[3] + 0g[4]
f[i][4] = 0g[0] + 0g[1] + 1g[2] + 1g[3] + 0g[4]
"""
if n == 1:
return 5
m = np.mat([
[0,1,1,0,1],
[1,0,1,0,0],
[0,1,0,1,0],
[0,0,1,0,0],
[0,0,1,1,0]
])
f0 = np.mat([
[1],
[1],
[1],
[1],
[1],
])
n -= 1
while n:
if n &1:
f0 = m*f0 %MOD
m = m*m%MOD
n >>= 1
return int(f0.sum()) %MOD
4. 552. 学生出勤记录 II(2维状态机DP展开成1维,用numpy)
链接: 552. 学生出勤记录 II
- 这题状态维度多一层,还好是2*3,可以直接展开。
MOD = 10 ** 9 + 7
import numpy as np
class Solution:
def checkRecord(self, n: int) -> int:
'''
f[i][0/1][0/1/2]表示i天,A=0,1,最近连续L为0/1/2时的情况
'''
# f = [[0]*3 for _ in range(2)]
# f[0][0] = 1
# for _ in range(n):
# g = [[0]*3 for _ in range(2)]
# g[0][0] = f[0][0] + f[0][1] + f[0][2]
# g[0][1] = f[0][0]
# g[0][2] = f[0][1]
# g[1][0] = f[0][0] + f[0][1] + f[0][2] + f[1][0] + f[1][1] + f[1][2]
# g[1][1] = f[1][0]
# g[1][2] = f[1][1]
# for i in range(2):
# for j in range(3):
# g[i][j] %= MOD
# f = g
# return sum(sum(row) for row in f) %MOD
'''
0 0:0
0 1:1
0 2:2
1 0:3
1 1:4
1 2:5
f[i][0] = [1, 1, 1, 0, 0, 0]
f[i][1] = [1, 0, 0, 0, 0, 0]
f[i][2] = [0, 1, 0, 0, 0, 0]
f[i][3] = [1, 1, 1, 1, 1, 1]
f[i][4] = [0, 0, 0, 1, 0, 0]
f[i][5] = [0, 0, 0, 0, 1, 0]
f[n] = m^n * [[1],[0],[0],[0],[0],[0]]
'''
f0 = np.mat([[1],[0],[0],[0],[0],[0]])
m = np.mat([
[1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
])
while n:
if n & 1:
f0 = m * f0 %MOD
m = m * m %MOD
n >>= 1
return int(f0.sum()%MOD)
5. 2851. 字符串转换(KMP+矩阵快速幂)
链接: 2851. 字符串转换
- 难点在于想到kmp,以及状态转移。
MOD = 10 ** 9 + 7
class Kmp:
"""kmp算法,计算前缀函数pi,根据pi转移,复杂度O(m+n)"""
def __init__(self, t):
"""传入模式串,计算前缀函数"""
self.t = t
n = len(t)
self.pi = pi = [0] * n
j = 0
for i in range(1, n):
while j and t[i] != t[j]:
j = pi[j - 1] # 失配后缩短期望匹配长度
if t[i] == t[j]:
j += 1 # 多配一个
pi[i] = j
def find_all_yield(self, s):
"""查找t在s中的所有位置,注意可能为空"""
n, t, pi, j = len(self.t), self.t, self.pi, 0
for i, v in enumerate(s):
while j and v != t[j]:
j = pi[j - 1]
if v == t[j]:
j += 1
if j == n:
yield i - j + 1
j = pi[j - 1]
def find_one(self, s):
"""查找t在s中的第一个位置,如果不存在就返回-1"""
for ans in self.find_all_yield(s):
return ans
return -1
def matrix_multiply(a, b, MOD=10**9+7):
m, n, p = len(a), len(a[0]), len(b[0])
ans = [[0]*p for _ in range(m)]
for i in range(m):
for j in range(n):
for k in range(p):
ans[i][k] = (ans[i][k]+a[i][j] * b[j][k]) %MOD
return ans
def matrix_pow_mod(a, b, MOD=10**9+7):
n = len(a)
ans = [[0]*n for _ in range(n)]
for i in range(n):
ans[i][i] = 1
while b:
if b & 1:
ans = matrix_multiply(ans, a, MOD)
a = matrix_multiply(a, a, MOD)
b >>= 1
return ans
class Solution:
def numberOfWays(self, s: str, t: str, k: int) -> int:
if t not in s+s:
return 0
n = len(s)
c = len(list(Kmp(t).find_all_yield(s+s[:-1])))
m = [
[c-1, c],
[n-c,n-1-c]
]
m = matrix_pow_mod(m, k)
return m[0][s != t]
三、其他
- 一定别忘了取模。
- 通常小数据特判。