torch.mul
torch.mul(input, other, out=None)
功能:
对位相乘,可以广播
该函数能处理两种情况:
- input是矩阵/向量,other是标量
这个时候是就是input的所有元素乘上other - input是矩阵/向量,other是矩阵/向量
这时 o u t i = i n p u t i × o t h e r i out_i = input_i \times other_i outi=inputi×otheri,对位相乘,如果两个都是向量,则可以广播的
例子:
-
input和other的size相同的对位相乘
a: tensor([[ 1.8351, 2.1536], [-0.8320, -1.4578]]) b: tensor([[2.9355, 0.3450], [0.5708, 1.9957]]) c = torch.mul(a,b): tensor([[ 5.3869, 0.7429], [-0.4749, -2.9093]])
-
两个向量的广播
a: tensor([[ 1.8351, 2.1536], [-0.8320, -1.4578]]) b: tensor([[2.9355, 0.3450], [0.5708, 1.9957]]) c = torch.mul(a,b): tensor([[ 5.3869, 0.7429], [-0.4749, -2.9093]])
torch.mm
torch.mm(input, mat2, out=None)
解决的问题:
处理二维矩阵的乘法,而且也只能处理二维矩阵,其他维度要用torch.matmul
例子:
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)
tensor([[ 0.4851, 0.5037, -0.3633],
[-0.0760, -3.6705, 2.4784]])
torch.bmm
torch.bmm(input, mat2, out=None)
看函数名就知道,在torch.mm
的基础上加了个batch计算,不能广播
torch.matmul
torch.matmul(input, other, out=None)
功能:
适用性最多的,能处理batch、广播的矩阵:
- 如果第一个参数是一维,第二个是二维,那么给第一个提供一个维度
- 如果第一个是二维,第二个是一维,就是矩阵乘向量
- 带有batch的情况,可保留batch计算
- 维度不同时,可先广播,再batch计算
例子:
-
vector x vector
tensor1 = torch.randn(3) tensor2 = torch.randn(3) torch.matmul(tensor1, tensor2).size() torch.Size([])
-
matrix x vector
tensor1 = torch.randn(3, 4) tensor2 = torch.randn(4) torch.matmul(tensor1, tensor2).size() torch.Size([3])
-
batched matrix x broadcasted vecto
tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(4) torch.matmul(tensor1, tensor2).size() torch.Size([10, 3])
-
batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(10, 4, 5) torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5])
总结
对位相乘用torch.mul
,二维矩阵乘法用torch.mm
,batch二维矩阵用torch.bmm
,batch、广播用torch.matmul