Pytorch unsqueeze和squeeze 解释&示例

在PyTorch中,unsqueezesqueeze函数用于改变张量的形状(即增加或减少维度)。unsqueeze函数将张量的指定维度扩展,使其具有大小为1的新维度,而squeeze函数则移除张量中大小为1的维度。

以下是使用示例:

示例1:使用unsqueeze添加维度
import torch

# 创建一个大小为(3, 5)的张量
x = torch.rand(3, 5)

# 在第0维上添加一个大小为1的维度
x = torch.unsqueeze(x, 0)

print(x.shape)  # 输出: torch.Size([1, 3, 5])

在这个示例中,我们使用unsqueeze函数在第0维上添加了一个新维度,将原始的大小为(3, 5)的张量变成了大小为(1, 3, 5)的张量。

示例2:使用squeeze去除维度
import torch

# 创建一个大小为(1, 3, 5, 1)的张量
x = torch.rand(1, 3, 5, 1)

# 去除所有大小为1的维度
x = torch.squeeze(x)

print(x.shape)  # 输出: torch.Size([3, 5])

在这个示例中,我们使用squeeze函数去除所有大小为1的维度,将原始的大小为(1, 3, 5, 1)的张量变成了大小为(3, 5)的张量。

示例3:使用squeeze去除指定维度
import torch

# 创建一个大小为(1, 3, 5, 1)的张量
x = torch.rand(1, 3, 5, 1)

# 去除第0维和第3维上的大小为1的维度
x = torch.squeeze(x, 0)
x = torch.squeeze(x, 2)

print(x.shape)  # 输出: torch.Size([3, 5])

在这个示例中,我们使用squeeze函数去除第0维和第3维上的大小为1的维度,将原始的大小为(1, 3, 5, 1)的张量变成了大小为(3, 5)的张量。

一个复杂的示例
indices = torch.tensor([[1, 3], [0, 2]])
index = indices.unsqueeze(-1).repeat(1, 1, input_tensor.shape[-1])

这段代码主要是利用PyTorch的张量广播机制对给定的索引张量进行扩展,以便于在另一个张量中进行元素的收集。

首先,代码定义了一个大小为(2, 2)的索引张量indices,表示需要在一个张量中收集第1、3行和第0、2行的元素。

然后,使用unsqueeze函数在indices张量的最后一个维度(即列维度)上增加了一个维度,得到的形状是(2, 2, 1)。接着,使用repeat函数将这个张量在最后一个维度上重复input_tensor.shape[-1]次,得到的形状是(2, 2, D),其中D表示input_tensor张量的最后一个维度大小,也就是特征维度。

这样,就可以使用得到的index张量对input_tensor张量中需要收集的元素进行索引和收集操作。

总的来说,这段代码的作用是将一个大小为(2, 2)的索引张量扩展为一个大小为(2, 2, D)的张量,以便于在另一个张量中进行元素的收集。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值