背景
pytorch已经有了unsqueeze操作来增加tensor的维度,除了这个可以使用一个更显示,更直观,更简单的方法 -> 插入None来伪装一个axis
具体方法
import torch
x = torch.randn(8)
print(x.shape) # torch.Size([8])
1 在所有维度之前插入一个维度
x = torch.randn(8)
x = x[None, :]
print(x.shape) # torch.Size([1, 8])
2 一次性插入多个维度
x = torch.randn(8)
x = x[None, None, :]
print(x.shape) # torch.Size([1, 1, 8])
3 省去冒号,隐式的把当前所有维度置于最后面
x = torch.randn(8)
x = x[None]
print(x.shape) # torch.Size([1, 8])
x = torch.randn(8)
x = x[None, None]
print(x.shape) # torch.Size([1, 1, 8])
3 利用冒号,灵活的操纵维度
import torch
a = torch.randn(4,3)
print(a.shape) # torch.Size([4, 3])
b = a[None, :, :] # == a[None]
print(b.shape) # torch.Size([1, 4, 3])
c = a[:, None, :] # == a[:, None]
print(c.shape) # torch.Size([4, 1, 3])
d = a[:, :, None]
print(d.shape) # torch.Size([4, 3, 1])
这里面需要注意一下a[:, None, :] 等价于 a[:, None],因为后者省去了最后一个冒号/维度
Reference
https://sparrow.dev/adding-a-dimension-to-a-tensor-in-pytorch/
.
.
.
.
.