在PyTorch中,unsqueeze和squeeze函数用于改变张量的形状(即增加或减少维度)。unsqueeze函数将张量的指定维度扩展,使其具有大小为1的新维度,而squeeze函数则移除张量中大小为1的维度。
以下是使用示例:
示例1:使用unsqueeze添加维度
import torch
# 创建一个大小为(3, 5)的张量
x = torch.rand(3, 5)
# 在第0维上添加一个大小为1的维度
x = torch.unsqueeze(x, 0)
print(x.shape) # 输出: torch.Size([1, 3, 5])
在这个示例中,我们使用unsqueeze函数在第0维上添加了一个新维度,将原始的大小为(3, 5)的张量变成了大小为(1, 3, 5)的张量。
示例2:使用squeeze去除维度
import torch
# 创建一个大小为(1, 3, 5, 1)的张量
x = torch.rand(1, 3, 5, 1)
# 去除所有大小为1的维度
x = torch.squeeze(x)
print(x.shape) # 输出: torch.Size([3, 5])
在这个示例中,我们使用squeeze函数去除所有大小为1的维度,将原始的大小为(1, 3, 5, 1)的张量变成了大小为(3, 5)的张量。
示例3:使用squeeze去除指定维度
import torch
# 创建一个大小为(1, 3, 5, 1)的张量
x = torch.rand(1, 3, 5, 1)
# 去除第0维和第3维上的大小为1的维度
x = torch.squeeze(x, 0)
x = torch.squeeze(x, 2)
print(x.shape) # 输出: torch.Size([3, 5])
在这个示例中,我们使用squeeze函数去除第0维和第3维上的大小为1的维度,将原始的大小为(1, 3, 5, 1)的张量变成了大小为(3, 5)的张量。
一个复杂的示例
indices = torch.tensor([[1, 3], [0, 2]])
index = indices.unsqueeze(-1).repeat(1, 1, input_tensor.shape[-1])
这段代码主要是利用PyTorch的张量广播机制对给定的索引张量进行扩展,以便于在另一个张量中进行元素的收集。
首先,代码定义了一个大小为(2, 2)的索引张量indices,表示需要在一个张量中收集第1、3行和第0、2行的元素。
然后,使用unsqueeze函数在indices张量的最后一个维度(即列维度)上增加了一个维度,得到的形状是(2, 2, 1)。接着,使用repeat函数将这个张量在最后一个维度上重复input_tensor.shape[-1]次,得到的形状是(2, 2, D),其中D表示input_tensor张量的最后一个维度大小,也就是特征维度。
这样,就可以使用得到的index张量对input_tensor张量中需要收集的元素进行索引和收集操作。
总的来说,这段代码的作用是将一个大小为(2, 2)的索引张量扩展为一个大小为(2, 2, D)的张量,以便于在另一个张量中进行元素的收集。