pytorch中提供了 matmul、mm和bmm等矩阵的乘法运算功能,但其具体计算细节和场景截然不同,应予以注意和区别。
1. torch.mm
该函数即为矩阵的乘法,torch.mm(tensor1, tenor2),参与计算的两个张量必须为二维张量(即矩阵),其结果张量out的维度关系满足:
o
u
t
(
p
×
q
)
=
t
1
(
p
×
m
)
∗
t
2
(
m
×
q
)
out(p\times q)=t_1(p\times m)*t_2(m\times q)
out(p×q)=t1(p×m)∗t2(m×q)
2. torch.bmm
该函数提供了batch维度的矩阵乘法运算,即batch内对应的矩阵两两相乘,因此要求参与计算的两个张量必须为三维张量,其中第一个维度为batch_size,必须相同,其结果张量 out的维度关系满足:
o
u
t
(
b
×
p
×
q
)
=
t
1
(
b
×
p
×
m
)
∗
t
2
(
b
×
m
×
q
)
out(b\times p \times q)=t_1(b\times p\times m)*t_2(b\times m\times q)
out(b×p×q)=t1(b×p×m)∗t2(b×m×q)
3. torch.matmul
该函数的功能相较于前面两个要丰富的多,其支持参与运算的两个张量有不同的维度,计算的机理也各不相同,具体包括:
(1) 两个张量均为1维张量(即向量)时,计算向量的内积
(2) 两个张量均为2维张量(即矩阵)时,计算矩阵的乘法
(3) 第一个向量为1维张量,第二个张量为2维张量,对第一个张量视情进行broadcast,然后进行矩阵乘法,在将上述结果squeeze为1维;
(4) 第二个向量为1维张量,第一个张量为2维张量,计算矩阵和向量的乘法;
(5) 两个向量维度至少为1,且其中至少一个张量的维度大于2;则先进行broadcast,然后进行bmm操作,最后将上述结果转换会broadcast之前的效果。
4851

被折叠的 条评论
为什么被折叠?



