上一篇博客总结了numpy中常用的乘法函数:
numpy常用乘法函数总结:np.dot()、np.multiply()、*、np.matmul()、@、np.prod()、np.outer()-CSDN博客
主要是 np.dot()、np.multiply()、*、np.matmul()、@ 五种,其中 np.matmul() 和 @ 完全等价,np.multiply() 和 * 在输入数据类型为 np.array 时也完全等价
本文总结Pytorch常用的乘法函数,TensorFlow常用乘法函数放在下一篇文章里
目录
torch.mul() 和 *【等价,element-wise乘,可广播】
torch.mm() 或 torch.bmm() 【矩阵乘法,前二维后三维,均不可广播】
@【等价于 torch.dot() + torch.mv() + torch.mm()】
torch.mul() 和 *【等价,element-wise乘,可广播】
torch.mul(x, y) 等价于 x*y
- 矩阵 * 标量:矩阵中的每个元素都乘以标量
- 矩阵 * 行向量:要求矩阵的列数 = 行向量的列数
- 矩阵 * 列向量:要求矩阵的行数 = 列向量的行数
- 矩阵 * 矩阵:要求两个矩阵的维度完全相同,不同的话看是否能广播成相同shape
vec1 = torch.arange(4)
vec2 = torch.tensor([4,3,2,1])
mat1 = torch.arange(12).reshape(4,3)
mat2 = torch.arange(12).reshape(3,4)
print(vec1 * vec2)
print(mat2 * vec1)
print(mat1 * mat1)
Output:
tensor([0, 3, 4, 3])
tensor([[ 0, 1, 4, 9],
[ 0, 5, 12, 21],
[ 0, 9, 20, 33]])
tensor([[ 0, 1, 4],
[ 9, 16, 25],
[ 36, 49, 64],
[ 81, 100, 121]])
一句话总结:如果两者shape相同,直接element-wise相乘;如果两者shape不同,本质上是先把另一个扩展/复制成相同的shape(即广播操作),再对应元素相乘
广播操作
(1)在一定的规则下允许高维Tensor和低维Tensor之间的运算。这里举一个例子:a是二维Tensor,b是三维Tensor,但是a的维度与b的后两位相同,那么a和b仍然可以做 * 操作,结果是一个和b维度一样的三维Tensor。可以理解为沿着b的第0维做二维Tensor点积,或运算前将a沿着b的第0维进行了expand操作
import torch
a = torch.tensor([[1, 2], [2, 3]]) # torch.Size([2, 2])
b = torch.tensor([[[1, 2], [2, 3]], [[-1, -2], [-2, -3]]]) # torch.Size([2, 2, 2])
print(a*b) # torch.Size([2, 2, 2]) b*a 结果相同
'''
tensor([[[ 1, 4],
[ 4, 9]],
[[-1, -4],
[-4, -9]]])
'''
运算过程可以理解为:
a = a.expand(b.size())
'''
tensor([[[1, 2],
[2, 3]],
[[1, 2],
[2, 3]]])
'''
a*b
(2)再举个例子:两个Tensor都是3维,但shape不一致,这种情况看是否能够广播成一致的shape
import torch
node = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]) # shape=(2,5)
node_0 = node.unsqueeze(-1) # torch.Size([2, 5, 1])
'''
tensor([[[1],
[1],
[1],
[0],
[0]],
[[1],
[1],
[1],
[1],