1. reshape
a = torch.rand(4, 1, 28, 28)
print(a.shape)
b = a.reshape(4, 28*28)
print(b.shape)
torch.Size([4, 1, 28, 28])
torch.Size([4, 784])
2. 维度添加和挤压
添加维度:(在所填写的数字前面添加)
a = torch.tensor([1.2, 2.3])
print(a)
b = a.unsqueeze(0) # 在第0个维度前添加一个维度
print(b)
c = a.unsqueeze(1) # 在第1个维度前添加一个维度
print(c)
tensor([1.2000, 2.3000])
tensor([[1.2000, 2.3000]])
tensor([[1.2000],
[2.3000]])
减少维度(给出减少维度的编号即可,只可以减少为1的)
a = torch.rand(1, 1, 28, 28)
print(a.shape)
b = a.squeeze(0) # 减少第0个维度
print(b.shape)
c = a.squeeze(2) # 减少第2个维度,但不为1,不会减少
print(c.shape)
torch.Size([1, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 1, 28, 28])