1、基础
张量维度:维度个数和维度大小;.ndim可查看维度个数,.shape可查看维度大小。
如下代码,张量a:维度个数为2,是一个2维张量;维度大小为[2,3],即第0维的维度大小为2,第1维为3。
>>> a=torch.arange(8).reshape(2,4)
>>> a
tensor([[0, 1, 2, 3],
[4, 5, 6, 7]])
>>> a.ndim
2
>>> a.shape
torch.Size([2, 4])
>>> a.size()
torch.Size([2, 4])
2、torch.matmul
这是最复杂也是功能最强大的乘法函数 ,可实现混合矩阵乘法。
参数:
-
input
(张量):第1个张量。 -
other
(张量):第2个张量。 -
out
(张量):结果张量,等同于torch.matmul函数的返回值。
返回:
-
张量。
CASE:
(1)若2个张量皆为1维张量,即向量点积,等价于torch.dot函数。结果为scalar标量。
(2)若2个张量皆为2维张量,即矩阵乘法,等价于torch.mm函数。结果为2维张量。
(3)若第1个张量为1维张量,假设维度为[k],第二个张量为2维张量,假设维度为[k,p]。第一个张量会在左边进行维度扩展,维度变为[1,k],然后再进行矩阵乘法,获得维度为[1,p]的张量,然后再去掉扩展的维度,最后结果张量维度为[p]。
(4)若第1个张量为2维张量,假设维度为[k,n],第2个张量为1维张量,假设维度为[n]。第2个张量会在右边进行维度扩展,维度变为[n,1],然后再执行矩阵乘法,获得维度为[k,1]的张量,最后再去掉扩展的维度,获得维度为[k]的结果张量。
(5)如果2个张量的维度均至少为1,且其中至少一个张量维度大于2,那么matmul将执行批矩阵乘法操作:默认使用两个张量的后两维度执行矩阵乘法,其他维度作为batch维。这个复杂,举个例子:
若两个张量均为3维张量,矩阵个数相等(第0维大小相等)且后两维满足矩阵乘法约束,那么调用matmul等价于调用torch.bmm函数
import torch
a=torch.arange(27).reshape(3,3,3)
print(a.size())
print(a.ndim)
b=torch.arange(1,10).reshape(3,3,1)
print(b.size())
print(b.ndim)
c1=torch.matmul(a,b)
c2=torch.bmm(a,b)
c1.equal(c2)
>>torch.Size([3, 3, 3])
>>3
>>torch.Size([3, 3, 1])
>>3
>>True
可能出现torch.bmm无法执行,但torch.matmul仍可执行:两个张量的后两维需满足矩阵乘法约束,不满足的情形会进行维度扩展([2]>>[2,1]),其他维则会通过广播操作([2]>>[2,1]>>[2,1,2,1])对齐。
a=torch.arange(12).reshape(2,1,3,2)
b=torch.arange(2)
c=torch.matmul(a,b)
print(a.ndim,b.ndim,c.ndim)
print(a.shape)
print(b.shape)
print(c.shape)
>>4 1 3
>>torch.Size([2, 1, 3, 2])
>>torch.Size([2])
>>torch.Size([2, 1, 3])
3、torch.dot
功能:向量点积。
参数:
-
input
(张量):第一个张量。 -
other
(张量):第二个张量。 -
out
(张量):结果张量,等同于dot函数的返回值。
返回值:张量(标量)。
重点:只支持具有相同元素个数的两个一维张量做点积操作。
4、torch.mm
功能:矩阵乘法。
参数:
input(张量):第一个矩阵,即2维张量。
mat2(张量):第二个矩阵,即2维张量。
out(张量):结果张量。
重点:torch.mm不会进行广播操作,它严格要求两个张量满足维度约束。即,假设两个张量分别为a、b,要求a.size()[1]=b.size()[0]。
4、torch.bmm
功能:批量矩阵乘法。
参数:
input(张量):第一批矩阵,即3维张量,第0维表示批大小。
mat2(张量):第二批矩阵,即3维张量,第0维表示批大小。
out(张量):结果张量,等同于torch.bmm函数返回值。也是3维张量,第0维表示批大小。
返回值:三维张量,第0维表示批大小。
重点:bmm不会进行广播操作,它严格要求两个张量均为三维张量,且第0维大小相等(表示有多少个矩阵),其他两维满足矩阵乘法约束。 即,假设两个张量分别为a、b,要求a.size()[0]=b.size()[0]且a.size()[-1]==b.size()[1]。
5、torch.mul
功能:逐元素相乘。
参数:
input(张量):第一个张量。
other(张量):第二个张量。
out(张量):结果张量,等同于mul函数的返回值。
返回值:张量。
重点:要求两个张量维度相同,即a.size()==b.size();若不同,则通过广播操作将相乘的两个张量的维度变得相同。同时,它的广播操作还会将两个张量类型统一。
6、总结
# 向量点积运算。要求输入为一维张量且类型相同、元素个数相同,输出为scaler标量。
torch.dot
# 矩阵乘法运算,不支持广播操作。要求输入为二维张量且类型相同,维度大小满足矩阵乘法约束。
torch.mm
# 批矩阵乘法运算,不支持广播操作。要求输入为三维张量且类型相同,第0维大小相等,后两维大小满足矩阵乘法约束。
torch.bmm
# 混合矩阵乘法运算,包括向量点积、矩阵乘法、批矩阵乘法,且支持广播操作(仅针对维度)。要求输入张量类型相同,具体行为根据维度可以分五种情况。
torch.matmul
# 逐元素乘法,等价于*,支持广播操作(包括维度及类型)。无特殊要求或约束。
torch.mul