1 torch.mul()
-
用标量值
value
乘以输入input
的每个元素,并返回一个新的结果张量。 \( out=tensor ∗ value \)。如果输入是FloatTensor or DoubleTensor类型,则value
必须为实数,否则须为整数。
torch.mul(input, value, out=None)
参数 | 描述 |
input (Tensor) | 输入张量 |
value (Number) | 乘到每个元素的数 |
out (Tensor) | 可选,输出张量 |
栗子:
>>> a = torch.randn(3)
>>> a
-0.9374
-0.5254
-0.6069
[torch.FloatTensor of size 3]
>>> torch.mul(a, 100)
-93.7411
-52.5374
-60.6908
[torch.FloatTensor of size 3]
-
两个张量
input
,other
按元素进行相乘,并返回到输出张量。即计算\( out_i=input_i ∗ other_i \)。两计算张量形状不须匹配,但总元素数须一致。
torch.mul(input, other, out=None)
参数 | 描述 |
input (Tensor) | 第一个相乘张量 |
other (Tensor) | 第二个相乘张量 |
out (Tensor) | 可选,输出张量 |
栗子:
>>> a = torch.randn(4,4)
>>> a
-0.7280 0.0598 -1.4327 -0.5825
-0.1427 -0.0690 0.0821 -0.3270
-0.9241 0.5110 0.4070 -1.1188
-0.8308 0.7426 -0.6240 -1.1582
[torch.FloatTensor of size 4x4]
>>> b = torch.randn(2, 8)
>>> b
0.0430 -1.0775 0.6015 1.1647 -0.6549 0.0308 -0.1670 1.0742
-1.2593 0.0292 -0.0849 0.4530 1.2404 -0.4659 -0.1840 0.5974
[torch.FloatTensor of size 2x8]
>>> torch.mul(a, b)
-0.0313 -0.0645 -0.8618 -0.6784
0.0934 -0.0021 -0.0137 -0.3513
1.1638 0.0149 -0.0346 -0.5068
-1.0304 -0.3460 0.1148 -0.6919
[torch.FloatTensor of size 4x4]
2 torch.mm()
处理二维矩阵的乘法,而且也只能处理二维矩阵,其他维度要用torch.matmul()。torch.mm(a, b)
是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵
torch.mm(input, mat2, out=None)
- 栗子
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)
tensor([[ 0.4851, 0.5037, -0.3633],
[-0.0760, -3.6705, 2.4784]])
3 torch.bmm()
torch.bmm(input, mat2, out=None)
看函数名就知道,在torch.mm
的基础上加了个batch计算,不能广播。
4 torch.matmul()
torch.matmul(input, other, out=None)
功能:
适用性最多的,能处理batch、广播的矩阵:
- 如果第一个参数是一维,第二个是二维,那么给第一个提供一个维度
- 如果第一个是二维,第二个是一维,就是矩阵乘向量
- 带有batch的情况,可保留batch计算
- 维度不同时,可先广播,再batch计算
栗子:
- vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()
torch.Size([])
- matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([3])
- batched matrix x broadcasted vecto
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
- batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
总结:
对位相乘用
torch.mul
,二维矩阵乘法用torch.mm
,batch二维矩阵用torch.bmm
,batch、广播用torch.matmul
参考: