torch.sequeeze
对维度进行压缩,以下为示例:
import torch
a = torch.zeros(2, 6, 1, 3, 1) # torch.Size([2, 6, 1, 3, 1])
# 对维度进行压缩
# a.squeeze(N) 若第 N 维的维度为 1,删除; 不为 1, 保留。
a_1 = a.squeeze(2)
print(a_1.shape) # torch.Size([2, 6, 3, 1])
a_2 = a.squeeze(1)
print(a_2.shape) # torch.Size([2, 6, 1, 3, 1])
# torch.squeeze(a,N) 与上面用法一致。
a_3 = torch.squeeze(a, 4)
print(a_3.shape) # torch.Size([2, 6, 1, 3])
# squeeze(a) 删除 a 中所有维度为1的维
a_4 = torch.squeeze(a)
print(a_4.shape) # torch.Size([2, 6, 3])
torch.unsequeeze
对维度进行扩充,以下为示例:
import torch
# 对维度进行扩充
b = torch.zeros(3, 2)
# b.unsqueeze(N) 在第 N 维加上唯数为 1 的维度
b_1 = b.unsqueeze(0)
print(b_1.shape) # torch.Size([1, 3, 2])
# torch.squeeze(a,N)
b_2 = torch.unsqueeze(b, 0)
print(b_2.shape) # torch.Size([1, 3, 2])