Pytorch中的unsqueeze函数是用来在张量中添加一个或多个新的维度的。可以通过提供一个包含想要添加的维度位置的列表来一次性添加多个维度。例如,unsqueeze([0, 2])会在0和2的位置添加新的维度。这样可以让你对不同形状的张量进行操作,比如乘法。返回的张量和原始张量共享相同的数据。
假设有一个形状为(3, 4)的张量x,在第1维和第3维添加新的维度,使得张量的形状变为(3, 1, 4, 1)。可以这样做:
import torch
x = torch.randn(3, 4)
print(x.shape) # (3, 4)
x1 = x.unsqueeze([1, 3])
print(x1.shape) # (3, 1, 4, 1)
相反,如果想删除多余的维度,可以使用torch.squeeze函数。这个函数会返回一个删除了所有大小为1的维度的张量。例如,如果有一个形状为(3, 1, 4, 1)的张量x,可以这样做:
import torch
x = torch.randn(3, 1, 4, 1)
print(x.shape) # (3, 1, 4, 1)
x2 = x.squeeze()
print(x2.shape) # (3, 4)