torch维度转换
PIL读入的图片的格式是
(H,W, C)
numpy储存图片的格式是(batch_size, H, W, C)
通常卷积需要的是(batch_size,C,H, W)
因此需要进行维度转换。
1、numpy中的维度转换
numpy中使用reshape
来进行形状变换。
transpose
的作用是坐标轴变换,切换角度来看待问题。
n = np.random.randn(2, 3, 4)
reshape_n = n.reshape(-1, 12)
print(reshape_n.shape) # (2, 12)
transpose_n = n.transpose(1, 2, 0)
print(transpose_n.shape) # (3, 4, 12)
2、torch维度转换
使用torch.view
来进行变换,尺寸转换,总数量不变。
t = torch.randn(2, 3, 4)
view_t = t.view(-1, 12)
print(view_t.shape) # torch.Size([2, 12])
torch.squeeze()/torch.unsqueeze()
这个是用来压缩或者添加维度的。
squeeze(n)
只能压缩第n个维度为1
的维