python矩阵乘法算法_Python算法|矩阵链乘法

概述

矩阵乘法是一个满足结合律的运算。显然,对于矩阵A、B、C来说,(AB)C 与 A(BC) 是等价的,我们可以根据自己的心情选择任意的运算顺序,总之,结果都是一样的。

糟糕的是,对计算机来说可不是这么回事,若我们假定矩阵 A=[10,20], B=[20,30], C=[30,40],那么在以下两种运算顺序中,标量相乘的次数是天差地别:

(AB)C = 10*20*30 + 10*30*40 = 18000

A(BC) = 20*30*40 + 10*20*40 = 32000

为了计算表达式,我们可以先用括号明确计算次序,然后利用标准的矩阵相乘算法进行计算。

完全括号化(fully parenthesized):它是单一矩阵,或者是两个完全括号化的矩阵乘积链的积。

例如如果有矩阵链为,则共有5种完全括号化的矩阵乘积链。

(A1(A2(A3A4)))

(A1((A2A3)A4))

((A1A2)(A3A4))

((A1(A2A3))A4)

((A1(A2A3))A4)

对矩阵链加括号的方式会对乘积运算的代价产生巨大影响。我们先来分析两个矩阵相乘的代价。下面的伪代码的给出了两个矩阵相乘的标准算法,属性rows和columns是矩阵的行数和列数。

MATRIX-MULTIPKLY(A,B)

if A.columns≠B.rows

error "incompatible dimensions"

else let C be a new A.rows×B.columns matrix

for i = 1 to A.rows

for j = 1 to B.columns

c(ij)=0

for k = 1 to A.columns

c(ij)=c(ij)+a(ik)*b(kj)

return C

两个矩阵A和B只有相容(compatible),即A的列数等于B的行数时,才能相乘。如果A是p×q的矩阵,B是q×r的矩阵,那么乘积C是p×r的矩阵。计算C所需要时间由第8行的标量乘法的次数决定的,即pqr。

以矩阵链为例,来说明不同的加括号方式会导致不同的计算代价。假设三个矩阵的规模分别为10×100、100×5和5×50。

如果按照((A1A2)A3)的顺序计算,为计算A1A2(规模10×5),需要做10*100*5=5000次标量乘法,再与A3相乘又需要做10*5*50=2500次标量乘法,共需7500次标量乘法。

如果按照(A1(A2A3))的顺序计算,为计算A2A3(规模100×50),需100*5*50=25000次标量乘法,再与A1相乘又需10*100*50=50000次标量乘法,共需75000次标量乘法。因此第一种顺序计算要比第二种顺序计算快10倍。

矩阵链乘法问题(matrix-chain multiplication problem)可描述如下:给定n个矩阵的链,矩阵Ai的规模为p(i-1)×p(i) (1<=i<=n),求完全括号化方案,使得计算乘积A1A2...An所需标量乘法次数最少。

因为括号方案的数量与n呈指数关系,所以通过暴力搜索穷尽所有可能的括号化方案来寻找最优方案是一个糟糕策略。我们可以使用递归关系来找到我们需要的最优解法,首先,我们要用一个函数MCM来得到最小标量相乘次数,那么MCM也可用来定义在所有情况下的最优子段。再使用动态规划和备忘录法即可得到结果。

应用动态规划方法

下面用动态规划方法来求解矩阵链的最优括号方案,我们还是按照之前提出的4个步骤进行:

1.刻画一个最优解的结构特征

2.递归地定义最优解的值

3.计算最优解的值,通常采用自底向上的方法

4.利用计算出的信息构造一个最优解

算法实现

def mult(chain):

n = len(chain)

# single matrix chain has zero cost

aux = {(i, i): (0,) + chain[i] for i in range(n)}

print(aux)

# i: length of subchain(子链)

for i in range(1, n):

# j: starting index of subchain

for j in range(0, n - i):

best = float('inf') #inf is infinite(无穷大)

# k: splitting point of subchain

for k in range(j, j + i):

# multiply subchains at splitting point

lcost, lname, lrow, lcol = aux[j, k]

rcost, rname, rrow, rcol = aux[k + 1, j + i]

cost = lcost + rcost + lrow * lcol * rcol

var = '(%s%s)' % (lname, rname)

print(cost, var)

# pick the best one

if cost < best:

best = cost

aux[j, j + i] = cost, var, lrow, rcol

print(aux)

return dict(zip(['cost', 'order', 'rows', 'cols'], aux[0, n - 1]))

结果

{(0, 0): (0, 'A', 10, 20),

(1, 1): (0, 'B', 20, 30),

(2, 2): (0, 'C', 30, 40)}

6000 (AB)

{(0, 0): (0, 'A', 10, 20),

(1, 1): (0, 'B', 20, 30),

(2, 2): (0, 'C', 30, 40),

(0, 1): (6000, '(AB)', 10, 30)}

24000 (BC)

{(0, 0): (0, 'A', 10, 20),

(1, 1): (0, 'B', 20, 30),

(2, 2): (0, 'C', 30, 40),

(0, 1): (6000, '(AB)', 10, 30),

(1, 2): (24000, '(BC)', 20, 40)}

32000 (A(BC))

{(0, 0): (0, 'A', 10, 20),

(1, 1): (0, 'B', 20, 30),

(2, 2): (0, 'C', 30, 40),

(0, 1): (6000, '(AB)', 10, 30),

(1, 2): (24000, '(BC)', 20, 40),

(0, 2): (32000, '(A(BC))', 10, 40)}

18000 ((AB)C)

{(0, 0): (0, 'A', 10, 20),

(1, 1): (0, 'B', 20, 30),

(2, 2): (0, 'C', 30, 40),

(0, 1): (6000, '(AB)', 10, 30),

(1, 2): (24000, '(BC)', 20, 40),

(0, 2): (18000, '((AB)C)', 10, 40)}

{'cost': 18000, 'order': '((AB)C)', 'rows': 10, 'cols': 40}

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值