python 三维矩阵乘以二维矩阵_如何将一个二维数组和一个三维数组矩阵相乘得到一个三维数组?...

在Python中,使用numpy处理三维矩阵与二维矩阵相乘时,不能直接通过折叠3d数组的第一个维度进行点积。推荐的做法是先将3d数组重塑为2d数组,执行矩阵乘法后再恢复为3d数组。这种方法利用了numpy的优化BLAS代码,提高了效率。此外,还介绍了使用einsum函数的更直观但可能较慢的方法。通过测试,发现直接重塑方法比转置后再计算更快。
摘要由CSDN通过智能技术生成

问题是numpy把多维数组看作矩阵的栈,最后两个维总是假定为线性空间维。这意味着点积不能通过折叠3d数组的第一个维度来工作。在

相反,你能做的最简单的事情就是把你的3d数组重塑成2d数组,做矩阵乘法,然后再重新整形成3d数组。这也将使用优化的BLAS代码,这是numpy的一大优势。在import numpy as np

S_pinv = np.random.rand(3, 4)

images = np.random.rand(4, 5, 6)

# error:

# (S_pinv @ images).shape

res_shape = S_pinv.shape[:1] + images.shape[1:] # (3, 5, 6)

res = (S_pinv @ images.reshape(images.shape[0], -1)).reshape(res_shape)

print(res.shape) # (3, 5, 6)

所以我们用(3,n) x (n,h,w)代替(3,n) x (n, h*w) -> (3, h*w),我们把它改回(3, h, w)。重塑是免费的,因为这并不意味着对内存中的数据进行任何实际操作(只是对数组下的单个内存块的重新解释),正如我所说的,适当的矩阵乘积在numpy中得到了高度优化。在

既然您要求一种更直观的方法,这里有一种利用^{}的替代方法。它可能会慢一些,但如果你稍微习惯它的符号,它是非常透明的:

^{pr2}$

这个符号命名了输入数组的每个维度:对于S_pinv,第一个和第二个维度分别命名为t和{},类似地,n,h和{}表示{}。输出被设

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值