数据维度转换 torch.reshape()
1. 为什么要使用 reshape()
函数
- 对于不同的网络结构如:一维卷积核、二维卷积核等,对输入数据维度要求并不相同,
reshape()
函数提供了非常方便的数据维度转换功能 torch.reshape()
提供了数据维度转换功能,在使用对数据维度有一定限制的网络结构时,一定要注意维度问题!!
2. 维度问题实例
-
nn.Conv2d()
接受的数据的维度必须是 (N, C, H, W)四维 或者 (C, H, W)三维的Tensor数据,对于下面的样例,就是因为数据维度问题而报错torch.manual_seed(0) input = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # torch.Size([3, 3]) conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(2, 2),) output = conv(input) # RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [3, 3]
(N, C, H, W):
N表示多少数据数量,C表示数据的通道数,H表示数据的行数,W表示数据的列数- 输入数据为三维
(C, H, W)
:表示需要进行卷积的数据只有一个通道数为C行数为H列数为W,在图像处理中具体表现为只需要卷积一张图片数据 - 输入数据为思维
(N, C, H, W):
表示共计有N个数据需要进行卷积,是一个数量为N的批处理
3. torch.reshape() 函数介绍
-
为了保证数据的维度正确,在进行卷积之前,可以使用
torch.reshape()
函数塑造数据格式torch.reshape(input, shape)
介绍input:
指定需要传入的数据shape:
使用元组形式表明需要设置的数据的维度- 指定的shape的元素数量一定要和input中的元素数量相对应,可以使用-1让函数自动计算某个维度的数量,一个shape参数只能包含一个-1
- 函数以行优先策略编排数据元素
4. torch.reshape()函数使用实例
-
设二维Tensor(3*3),模拟一张(3*3)的图像,需要进行卷积
torch.manual_seed(0) input = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # torch.Size([3, 3]) input = torch.reshape(input = input, shape = (-1, 3, 3)) input.shape # torch.Size([1, 3, 3])
- 在图像处理中,其实图像数据本身的维度是知道的,类似于上例中的二维数据模拟一个 3*3 的图像,我们想将它扩充为三维数据,在指定
shape=(-1, 3, 3)
时,会自动计算对应位置的个数
- 在图像处理中,其实图像数据本身的维度是知道的,类似于上例中的二维数据模拟一个 3*3 的图像,我们想将它扩充为三维数据,在指定
-
设二维Tensor(3*9),模拟了一张三通道的图像,需要进行卷积,每个通道在第二个维度上堆叠构建成输入数据
torch.manual_seed(0) input = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24], [25, 26, 27]]) # torch.Size([3, 9]) input = torch.reshape(input = input, shape = (-1, 3, 3)) input.shape # torch.Size([3, 3, 3])
torch.reshape
会直接计算确定通道数,不需要直接指定
-
同样对于上面的例子,假设不是三通道图像,而是三张一通道图像在第二维度上的堆叠,构建了输入数据
torch.manual_seed(0) input = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24], [25, 26, 27]]) # torch.Size([3, 9]) input = torch.reshape(input = input, shape = (-1, 1, 3, 3)) input.shape # torch.Size([3, 1, 3, 3])
- 指定的
shape=(-1, 1, 3, 3)
函数将根据后三项(1, 3, 3)
图像数据中表示为一通道3*3的图像,通过行优先策略确定批数据中的图像数量 N。
- 指定的
总而言之:torch.reshape()
提供了数据维度转换功能,在使用对数据维度有一定限制的网络结构时,一定要注意维度问题!!