15. 数据维度转换 -- torch.reshape

数据维度转换 torch.reshape()
1. 为什么要使用 reshape() 函数
  • 对于不同的网络结构如:一维卷积核、二维卷积核等,对输入数据维度要求并不相同,reshape()函数提供了非常方便的数据维度转换功能
  • torch.reshape() 提供了数据维度转换功能,在使用对数据维度有一定限制的网络结构时,一定要注意维度问题!!
2. 维度问题实例
  • nn.Conv2d()接受的数据的维度必须是 (N, C, H, W)四维 或者 (C, H, W)三维的Tensor数据,对于下面的样例,就是因为数据维度问题而报错

    torch.manual_seed(0)
    input = torch.Tensor([[1, 2, 3], 
                          [4, 5, 6], 
                          [7, 8, 9]]) # torch.Size([3, 3])
    conv = nn.Conv2d(in_channels=1,
                     out_channels=1,
                     kernel_size=(2, 2),)
    output = conv(input)
    # RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [3, 3]
    
    • (N, C, H, W):N表示多少数据数量,C表示数据的通道数,H表示数据的行数,W表示数据的列数
    • 输入数据为三维 (C, H, W):表示需要进行卷积的数据只有一个通道数为C行数为H列数为W,在图像处理中具体表现为只需要卷积一张图片数据
    • 输入数据为思维 (N, C, H, W): 表示共计有N个数据需要进行卷积,是一个数量为N的批处理
3. torch.reshape() 函数介绍
  • 为了保证数据的维度正确,在进行卷积之前,可以使用 torch.reshape() 函数塑造数据格式

    • torch.reshape(input, shape)介绍
      • input:指定需要传入的数据
      • shape:使用元组形式表明需要设置的数据的维度
      • 指定的shape的元素数量一定要和input中的元素数量相对应,可以使用-1让函数自动计算某个维度的数量,一个shape参数只能包含一个-1
      • 函数以行优先策略编排数据元素
4. torch.reshape()函数使用实例
  • 设二维Tensor(3*3),模拟一张(3*3)的图像,需要进行卷积

    torch.manual_seed(0)
    input = torch.Tensor([[1, 2, 3], 
                          [4, 5, 6], 
                          [7, 8, 9]]) # torch.Size([3, 3])
    input = torch.reshape(input = input,
                         shape = (-1, 3, 3)) 
    input.shape # torch.Size([1, 3, 3])
    
    • 在图像处理中,其实图像数据本身的维度是知道的,类似于上例中的二维数据模拟一个 3*3 的图像,我们想将它扩充为三维数据,在指定shape=(-1, 3, 3)时,会自动计算对应位置的个数
  • 设二维Tensor(3*9),模拟了一张三通道的图像,需要进行卷积,每个通道在第二个维度上堆叠构建成输入数据

    torch.manual_seed(0)
    input = torch.Tensor([[1, 2, 3], 
                          [4, 5, 6], 
                          [7, 8, 9],
                          [10, 11, 12],
                          [13, 14, 15],
                          [16, 17, 18],
                          [19, 20, 21],
                          [22, 23, 24],
                          [25, 26, 27]]) # torch.Size([3, 9]) 
    input = torch.reshape(input = input,
                         shape = (-1, 3, 3)) 
    input.shape # torch.Size([3, 3, 3])
    
    • torch.reshape 会直接计算确定通道数,不需要直接指定
  • 同样对于上面的例子,假设不是三通道图像,而是三张一通道图像在第二维度上的堆叠,构建了输入数据

    torch.manual_seed(0)
    input = torch.Tensor([[1, 2, 3], 
                          [4, 5, 6], 
                          [7, 8, 9],
                          [10, 11, 12],
                          [13, 14, 15],
                          [16, 17, 18],
                          [19, 20, 21],
                          [22, 23, 24],
                          [25, 26, 27]]) # torch.Size([3, 9]) 
    input = torch.reshape(input = input,
                         shape = (-1, 1, 3, 3)) 
    input.shape # torch.Size([3, 1, 3, 3])
    
    • 指定的 shape=(-1, 1, 3, 3) 函数将根据后三项 (1, 3, 3) 图像数据中表示为一通道3*3的图像,通过行优先策略确定批数据中的图像数量 N。

总而言之:torch.reshape() 提供了数据维度转换功能,在使用对数据维度有一定限制的网络结构时,一定要注意维度问题!!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

啥都想学的大学生

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值