问题是numpy把多维数组看作矩阵的栈,最后两个维总是假定为线性空间维。这意味着点积不能通过折叠3d数组的第一个维度来工作。在
相反,你能做的最简单的事情就是把你的3d数组重塑成2d数组,做矩阵乘法,然后再重新整形成3d数组。这也将使用优化的BLAS代码,这是numpy的一大优势。在import numpy as np
S_pinv = np.random.rand(3, 4)
images = np.random.rand(4, 5, 6)
# error:
# (S_pinv @ images).shape
res_shape = S_pinv.shape[:1] + images.shape[1:] # (3, 5, 6)
res = (S_pinv @ images.reshape(images.shape[0], -1)).reshape(res_shape)
print(res.shape) # (3, 5, 6)
所以我们用(3,n) x (n,h,w)代替(3,n) x (n, h*w) -> (3, h*w),我们把它改回(3, h, w)。重塑是免费的,因为这并不意味着对内存中的数据进行任何实际操作(只是对数组下的单个内存块的重新解释),正如我所说的,适当的矩阵乘积在numpy中得到了高度优化。在
既然您要求一种更直观的方法,这里有一种利用^{}的替代方法。它可能会慢一些,但如果你稍微习惯它的符号,它是非常透明的:
^{pr2}$
这个符号命名了输入数组的每个维度:对于S_pinv,第一个和第二个维度分别命名为t和{},类似地,n,h和{}表示{}。输出被设