torch.mm()
二维矩阵的乘法,假设输入矩阵mat1
维度是
(
m
×
n
)
(m×n)
(m×n),矩阵mat2
维度是
(
n
×
p
)
(n×p)
(n×p),则输出维度为
(
m
×
p
)
(m×p)
(m×p),只能是二维的
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
out = torch.mm(mat1, mat2)
'''
tensor([[ 0.4851, 0.5037, -0.3633],
[-0.0760, -3.6705, 2.4784]])
'''
torch.sparse.mm()
a是稀疏矩阵,b是稀疏矩阵或者密集矩阵,sparse.mm
的作用和torch.mm
一样,都是做矩阵乘法计算
a = torch.randn(2, 3).to_sparse().requires_grad_(True)
b = torch.randn(3, 2, requires_grad=True)
y = torch.sparse.mm(a, b)
'''
a:
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
[0, 1, 2, 0, 1, 2]]),
values=tensor([ 1.5901, 0.0183, -0.6146, 1.8061, -0.0112, 0.6302]),
size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True)
b:
tensor([[-0.6479, 0.7874],
[-1.2056, 0.5641],
[-1.1716, -0.9923]], requires_grad=True)
y:
tensor([[-0.3323, 1.8723],
[-1.8951, 0.7904]], grad_fn=<SparseAddmmBackward>)
'''
torch.bmm()
与torch.mm
类似,但多了一个batch_size
维度,输入矩阵张量mat1
维度是
(
b
×
m
×
n
)
(b×m×n)
(b×m×n),矩阵张量mat2
维度是
(
b
×
n
×
p
)
(b×n×p)
(b×n×p),则输出维度为
(
b
×
m
×
p
)
(b×m×p)
(b×m×p)
mat1 = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(mat1, mat2)
print(res.size())
# torch.Size([10, 3, 5])
torch.mul()
将输入张量input
的每个元素与另一个标量other
相乘,返回一个新的张量out
,两者维度需满足广播规则
# 方式1:张量 和 标量相乘
a = torch.randn(3)
torch.mul(a, 100)
'''
a: tensor([ 0.2015, -0.4255, 2.6087])
tensor([ 20.1494, -42.5491, 260.8663])
'''
# 方式2:张量 和 张量(需满足广播规则)
a = torch.randn(4, 1)
b = torch.randn(1, 4)
c = torch.mul(a,b)
'''
c:
tensor([[-0.1183, -0.4246, -0.0512, 0.1757],
[-0.4215, -1.5121, -0.1823, 0.6257],
[-0.0358, -0.1284, -0.0155, 0.0531],
[ 0.1649, 0.5917, 0.0713, -0.2448]])
'''
# 方式3:元素对应项相乘
a = torch.randn(3, 2)
b = torch.randn(3, 2)
c = torch.mul(a,b)
'''
C:
tensor([[-1.9259, -0.0116],
[-1.8523, -0.0392],
[-0.4881, -0.4235]])
'''
torch.matmul()
两个张量的矩阵乘积。其行为取决于张量的维数如下:
- 如果两个张量都是一维的,则返回点积(标量)。
- 如果两个参数都是二维的,则返回矩阵-矩阵乘积。
- 如果第一个参数是一维的,第二个参数是二维的,则在其维数前加一个1,以实现矩阵乘法。在矩阵相乘之后,附加的维度被删除。
- 如果第一个参数是二维的,第二个参数是一维的,则返回矩阵向量乘积。
- 如果两个参数至少是一维的,且至少一个参数是N维的(其中N > 2),则返回一个批处理矩阵乘法。如果第一个参数是一维的,则在其维数前加上1,以便批处理矩阵相乘,然后删除。如果第二个参数是一维的,则为批处理矩阵倍数的目的,将在其维上追加一个1,然后删除它。非矩阵(即批处理)维度是广播的(因此必须是可广播的)
# vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()
torch.Size([])
# matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([3])
# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
总结
- 二维矩阵乘积用
torch.mm()
或torch.sparse.mm()
- 多批次的二维矩阵之间的乘积用
torch.bmm()
- 标量乘积或对应项乘积用
torch.mul()
- 批次或广播进行乘积用
torch.matmul()