PyTorch常用函数 | squeeze与unsqueeze函数 | flatten函数 | Pytorch中的各种乘法 | mul与mv与mm与dot与matmul函数

一、squeeze与unsqueeze函数

  顾名思义,squeeze函数的作用是压缩(降维);unsqueeze函数的作用是解压(升维)。

1.squeeze函数

torch.squeeze(input, dim=None, *, out=None)

功能

  • 返回一个所有维度不为1的张量。
    比如,输入张量维度为A×1×B×C×1×D,则经squeeze函数的输出张量维度为A×B×C×D
  • dim参数设置时,只在所设置的dim维度进行张量的压缩,如果所设置的dim维度的尺寸不为1,则张量不会发生任何改变。
    比如,输入张量的维度为A×1×Bsqueeze(input, 0)函数不会使得输出张量的维度有任何改变。但是squeeze(input, 1)函数会使得输出张量的维度变为A×B

需要注意的是:

  1. 返回的张量与输入张量共享内存,因此,其中一个张量的内容改变,另一个张量的内容也会改变
  2. 如果一个张量batch维度设置为1,那么经squeeze函数将会移除batch维度,这会导致错误

参数

  • input(Tensor):输入张量
  • dim(int, optional):如果这个参数被设置,则输入张量仅在这个维度(dim)上进行压缩

例子:

x = torch.zeros(2, 1, 2, 1, 2)
x.size()
=>torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x)
y.size()
=>torch.Size([2, 2, 2])
y = torch.squeeze(x, 0)
y.size()
=>torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x, 1)
y.size()
=>torch.Size([2, 2, 1, 2])

2.unsqueeze函数

torch.unsqueeze(input, dim)

功能

  • 返回一个在指定位置插入尺寸为 1 的新张量
  • dim参数的区间为[-input.dim() - 1, input.dim() + 1),负的dim维度相当于dim = dim + input.dim() + 1

参数

  • input (Tensor):输入张量
  • dim (int):插入的索引

案例

# shape=4
x = torch.tensor([1, 2, 3, 4])  
# shape=(1,4)
torch.unsqueeze(x, 0)
=>tensor([[ 1,  2,  3,  4]])
# shape=(4,1)
torch.unsqueeze(x, 1)
=>t
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

幼稚的人呐

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值