# PyTorch 中的矩阵、向量、标量之间的乘法
# 一、torch.mul()
# 注意:torch.mul() 是支持广播操作
# torch.mul(input, value, out=None)
# 用标量值 value 乘以输入 input 的每个元素,并返回一个新的结果张量。 out = tensor ∗ value
# 如果输入是FloatTensor or DoubleTensor类型,则 value 必须为实数,否则须为整数。【译注:似乎并非如此,无关输入类型,value取整数、实数皆可。】
# 参数:
# input (Tensor) – 输入张量
# value (Number) – 乘到每个元素的数
# out (Tensor, optional) – 输出张量
import torch
a = torch.randn(3)
print("a : ", a) # tensor([-1.6289, 0.2446, -0.3691])
print("a.size() : ", a.size()) # torch.Size([3])
mul_a_100 = torch.mul(a, 100)
print("mul_a_100 : ", mul_a_100) # tensor([-162.8945, 24.4566, -36.9136])
print("mul_a_100.size() : ", mul_a_100.size()) # torch.Size([3])
print("*" * 50)
# 两个张量 input, other 按元素进行相乘,并返回到输出张量。即计算 outi = inputi ∗ otheri
# 两计算张量形状不须匹配,但总元素数须一致。 注意:当形状不匹配时,input的形状作为输入张量的形状。
#
# 参数:
#
# input (Tensor) – 第一个相乘张量
# other (Tensor) – 第二个相乘张量
# out (Tensor, optional) – 结果张量
c = torch.randn(4, 4)
print("c.size() : ", c.size()) # torch.Size([4, 4])
# d = torch.randn(2, 8) # torch.Size([2, 8]) 该形状不符合广播条件
# RuntimeError: The size of tensor a (4) must match the size of tensor b (8) at non-singleton dimension 1
d = torch.randn(1, 4) # 该形状符合广播条件
print("d.size() : ", d.size())
mul_c_d = torch.mul(c, d)
print("mul_c_d.size() : ", mul_c_d.size()) # torch.Size([4, 4])
# 二、torch.mm()
# 注意,torch.mm()不支持广播(broadcast)。
# torch.mm(mat1, mat2, out=None) → Tensor
# 对矩阵mat1和mat2进行相乘。 如果mat1 是一个n×m 张量,mat2 是一个 m×p 张量,将会输出一个 n×p 张量out。
print("^" * 50)
mat1 = torch.randn(2, 3)
print(mat1.size()) # torch.Size([2, 3])
# mat2 = torch.randn(1, 3) # 该形状不支持广播
# print(mat2.size()) # torch.Size([1, 3])
# RuntimeError: size mismatch, m1: [2 x 3], m2: [1 x 3] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:752
mat2 = torch.randn(3, 4)
print(mat2.size()) # torch.Size([3, 4])
mm = torch.mm(mat1, mat2)
print(mm.size()) # torch.Size([2, 4])
# 三、torch.mv()
# 注意,torch.mv()不支持广播(broadcast)
# torch.mv(mat, vec, out=None) → Tensor
# 对矩阵mat和向量vec进行相乘。 如果mat 是一个n×m张量,vec 是一个m元 1维张量,将会输出一个n 元 1维张量。
print("-" * 50)
mat = torch.randn(2, 3)
print(mat.size()) # torch.Size([2, 3])
# vec = torch.randn(2)
# RuntimeError: size mismatch, [2 x 3], [2] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:631
vec = torch.randn(3)
print(vec.size()) # torch.Size([3])
mv = torch.mv(mat, vec)
print(mv.size()) # torch.Size([2])
# 四、torch.dot()
# 注意,torch.dot()不支持广播(broadcast)
# torch.dot(tensor1, tensor2) → Tensor
# 计算两个张量的点乘(内乘),两个张量都为1-D 向量
print("=" * 50)
# x = torch.tensor([2, 3, 2]) # 该形状不支持广播
# print(x.size()) # torch.Size([3])
# RuntimeError: inconsistent tensor size, expected tensor [3] and src [2] to have the same number of elements,
# but got 3 and 2 elements respectively
x = torch.tensor([2, 3])
print(x.size()) # torch.Size([2])
y = torch.tensor([4, 1])
print(y.size()) # torch.Size([2])
dot = torch.dot(x, y)
print(dot) # tensor(11)
print(dot.size()) # torch.Size([])
print("~" * 50)
# 五、torch.matmul()
# 注意:torch.matmul() 支持广播
# torch.matmul(input, other, out=None) → Tensor
# 两个张量的矩阵乘积
# 计算结果取决于张量的维度:
# 1)如果两个张量都是 1 维,返回结果为 the dot product (scalar) 【点乘(标量)】
# 2)如果两个张量都是 2 维,返回结果为 the matrix-matrix product (矩阵乘积)
# 3)如果第一个参数是 1 维,第二个参数是 2 维,为了矩阵乘法的目的,在第一维上加 1(达到扩充维度的目的),
# 矩阵计算完成之后,第一维加上的 1 将会被删掉。
# 4)如果第一个参数是 2 维,第二个参数是 1 维,返回结果为 the matrix-vector product (矩阵向量乘积)
# 5)如果两个参数至少是 1 维且至少一个参数为 N 维(其中N> 2),则返回 batched matrix multiply (批处理矩阵乘法)
# 如果第一个参数是 1 维,则在其维数之前添加 1,以实现批量矩阵乘法并在计算之后删除 1。
# 如果第二个参数是 1 维,则在其维数之前添加 1,以实现批量矩阵乘法并在计算之后删除 1。
# 非矩阵(即批处理)尺寸被广播(因此必须是可广播的)。
# 例如,如果 input 的张量是 j×1×n×m ,
# other 的张量是 k×m×p,
# out 的张量将会是 j×k×n×p
# case 1:vector x vector
tensor1 = torch.randn(3)
print(tensor1.size()) # torch.Size([3])
tensor2 = torch.randn(3)
print(tensor2.size()) # torch.Size([3])
matmul_1_2 = torch.matmul(tensor1, tensor2)
print(matmul_1_2) # tensor(0.2001) -- scalar
print(matmul_1_2.size()) # torch.Size([])
# case 4: matrix x vector (该情况下不支持广播,matrix的列数必须要和vector的行数一致才能进行计算)
tensor3 = torch.randn(3, 4)
print(tensor3.size()) # torch.Size([3, 4])
tensor4 = torch.randn(4)
print(tensor4.size()) # torch.Size([4])
matmul_3_4 = torch.matmul(tensor3, tensor4)
print(matmul_3_4) # tensor([ 0.8020, 0.2547, -1.2333])
print(matmul_3_4.size()) # torch.Size([3])
# case 5:batched matrix x broadcasted vector
a = torch.randn(10, 3, 4)
print(a.size()) # torch.Size([10, 3, 4])
b = torch.randn(4)
print(b.size()) # torch.Size([4])
matmul_a_b = torch.matmul(a, b)
print(matmul_a_b.size()) # torch.Size([10, 3])
# case 5:batched matrix x batched matrix
c = torch.randn(10, 3, 4)
print(c.size()) # torch.Size([10, 3, 4])
d = torch.randn(10, 4, 5)
print(d.size()) # torch.Size([10, 4, 5])
matmul_c_d = torch.matmul(c, d)
print(matmul_c_d.size()) # torch.Size([10, 3, 5])
# case 5:batched matrix x broadcasted matrix
m = torch.randn(10, 3, 4)
print(m.size()) # torch.Size([10, 3, 4])
n = torch.randn(4, 5)
print(n.size()) # torch.Size([4, 5])
matmul_m_n = torch.matmul(m, n)
print(matmul_m_n.size()) # torch.Size([10, 3, 5])
矩阵维度必须一致_PyTorch 中的矩阵、向量、标量之间的乘法
最新推荐文章于 2023-09-14 22:35:16 发布