unsqueeze()
用于增加一个维度。
先假设有如下一维的Tensor.
a=torch.Tensor([1,2])
print(a.shape)
假设我们现在有一个2*2的矩阵b,要与a相乘,最规范的是应该a的形状要变成2*1才对,现在是2。所以要增加一个维度。使用tensor的一个函数unsqueeze(dim)。参数中指明哪一个维度要增加一维。我们要对a在第二维增加一个维度。
a=a.unsqueeze(1)
print(a.shape)
我们来要给直观的对比。
定义一个矩阵b,其形状为2*2,现在可以与矩阵a(2*1)相乘了。
b=torch.Tensor([[1,2],[3,4]])
torch.matmul(b,a)
b*a=(2*2)*(2*1)=2*1,结果的矩阵为:
print(torch.matmul(b,a).shape)
反过来由于我们发现a*b之后的那个矩阵最后一个维度是1。所以我们可以使用squeeze()
函数来删除最后一个维度。
c=torch.matmul(b,a)
c=c.squeeze(1)
c