pytorch中tensor的unsqueeze()函数和squeeze()函数的用处

unsqueeze()用于增加一个维度。
先假设有如下一维的Tensor.

a=torch.Tensor([1,2])
print(a.shape)

在这里插入图片描述
假设我们现在有一个2*2的矩阵b,要与a相乘,最规范的是应该a的形状要变成2*1才对,现在是2。所以要增加一个维度。使用tensor的一个函数unsqueeze(dim)。参数中指明哪一个维度要增加一维。我们要对a在第二维增加一个维度。

a=a.unsqueeze(1)
print(a.shape)

在这里插入图片描述
我们来要给直观的对比。
在这里插入图片描述

定义一个矩阵b,其形状为2*2,现在可以与矩阵a(2*1)相乘了。

b=torch.Tensor([[1,2],[3,4]])
torch.matmul(b,a)

b*a=(2*2)*(2*1)=2*1,结果的矩阵为:
在这里插入图片描述

print(torch.matmul(b,a).shape)

在这里插入图片描述
反过来由于我们发现a*b之后的那个矩阵最后一个维度是1。所以我们可以使用squeeze()函数来删除最后一个维度。

c=torch.matmul(b,a)
c=c.squeeze(1)
c

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

音程

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值