x.squeeze(dim)
用途:进行维度压缩,去掉tensor中维数为1的维度
参数设置:如果设置dim=a,就是去掉指定维度中维数为1的
示例:
import torch
x = torch.tensor([[[1],[2]],[[3],[4]]])
print('x:',x)
x1 = x.squeeze()
print('x1:',x1)
x2 = x.squeeze(2)
print('x2:',x2)
输出:
x: tensor([[[1],
[2]],
[[3],
[4]]])
x1: tensor([[1, 2],
[3, 4]])
x2: tensor([[1, 2],
[3, 4]])
Process finished with exit code 0
x.unsqueeze(dim=a)
用途:进行维度扩充,在指定位置加上维数为1的维度
参数设置:如果设置dim=a,就是在维度为a的位置进行扩充
示例:
import torch
x = torch.tensor([1,2,3,4])
print(x)
x1 = x.unsqueeze(dim=0)
print(x1)
x2 = x.unsqueeze(dim=1)
print(x2)
y = torch.tensor([[1,2,3,4],[9,8,7,6]])
print(y)
y1 = y.unsqueeze(dim=0)
print(y1)
y2 = y.unsqueeze(dim=1)
print(y2)
输出:
x: tensor([1, 2, 3, 4])
x1: tensor([[1, 2, 3, 4]])
x2: tensor([[1],
[2],
[3],
[4]])
y: tensor([[1, 2, 3, 4],
[9, 8, 7, 6]])
y1: tensor([[[1, 2, 3, 4],
[9, 8, 7, 6]]])
y2: tensor([[[1, 2, 3, 4]],
[[9, 8, 7, 6]]])
Process finished with exit code 0