pytroch中乘法,大致分为两类:
- 函数乘法
torch.mul(a,b)
- 对象调用的方法
a.mul(b)
本文按照使用频率排序,梳理了相关乘法的使用场景。
1. 运算符 *
- 若矩阵*单个数,则element-wise
- 矩阵*向量,符合广播规则
广播规则
torch.Tensor(4,3)*torch.Tensor(4)
#若矩阵*一维向量,则列对齐(列数相等)
torch.Tensor(4,3)*torch.Tensor(3,1)
#若矩阵*NX1维向量,则行对齐(行数相等)
- 矩阵*矩阵,只能element-wise
2.矩阵逐元素(Element-wise)乘法 torch.mul()
torch.mul(mat1, other, out=None)
其中 other 乘数可以是标量,也可以是任意维度的矩阵, 只要满足最终相乘满足广播规则即可。
大致和*相似
3.多维矩阵乘法 torch.matmul()
torch.matmul(input, other, out=None)
支持broadcast操作,使用起来比较复杂。针对多维数据 matmul() 乘法,可以认为该乘法使用使用两个参数的后两个维度来计算,其他的维度都可以认为是batch维度。
4.作为神经元的nn.Linear()
self.L=nn.Linear(10,5)
x=torch.Tensor(100,10)
self.L(x)# 输出维度100,5
相当于矩阵相乘: x × s l e f . L x\times slef.L x×slef.L
5.二维矩阵乘法 torch.mm()
torch.mm(mat1, mat2, out=None)
其中
m
a
t
1
∈
R
n
×
m
,
m
a
t
2
∈
R
m
×
d
,
o
u
t
∈
R
n
×
d
mat1\in R^{n\times m},mat2\in R^{m\times d},out\in R^{n\times d}
mat1∈Rn×m,mat2∈Rm×d,out∈Rn×d
该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
6.三维矩阵乘法 torch.bmm()
torch.bmm(mat1, mat2, out=None)
其中
m
a
t
1
∈
R
b
×
n
×
m
,
m
a
t
2
∈
R
b
×
m
×
d
,
o
u
t
∈
R
b
×
n
×
d
mat1\in R^{b\times n\times m},mat2\in R^{b\times m\times d},out\in R^{b\times n\times d}
mat1∈Rb×n×m,mat2∈Rb×m×d,out∈Rb×n×d
主要是多了外层batch(一组多个样本同时训练),该函数的两个输入必须是三维矩阵并且第一维相同(表示Batch维度), 不支持broadcast操作
7.矩阵逐元素(Element-wise)乘法 torch.mul()
torch.mul(mat1, other, out=None)
其中 other 乘数可以是标量,也可以是任意维度的矩阵, 只要满足最终相乘是可以broadcast的即可。
8.einsum
爱因斯坦求和??太过于稀有,暂时跳过