Pytorch中将tensor拉平的几种方法


前言

当我们在搭建网络时,tensor进入全连接层/GAP/GMP/分类器之前需要对tensor进行拉平操作,保留某个维度或者去除某个维度,本文试着总结一下常见的将tensor拉平的方法,如有问题希望大家批评指正。


一、卷积神经网络提取特征的流程

在计算机视觉领域,无论是图像分类还是目标检测,CNN常被用作图片特征提取的Backbone(主干网络)。CNN经过某些卷积操作生成feature map,降低分辨率,增大通道数。在进入最后的全连接层/分类器之前时,特征信息最多,往往此时需要保留通道数而忽略图片的宽高。本文以上一篇文章的MobileNetV2为例,阐述几种tensor拉平的方法。

二、几种常见方法

1.view():元素总数不变改变形状

MobileNetV2的forward部分:

 def forward(self, x):
        #2,3,32,32
        x = self.conv1(x)
        #2,3,32,32
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        x = self.bottleneck4(x)
        x = self.bottleneck5(x)
        x = self.bottleneck6(x)
        x = self.bottleneck7(x)
        #2,320,4,4
        x = self.conv2(x)
        #2,1280,4,4
        x = self.avgpool(x)
        #2,1280,1,1
        #tensor拉平发生的位置
        x = x.view(x.size[0],-1)
        #2,1280
        x = self.linear(x)
        #2,1000
        return x

Pytorch中tensor的输入格式为[B,C,H,W],分别代表batch_size,channels,高,宽。简单回顾一下,假设输入的tensor为[2,3,32,32],经过forward到全连接层之前的tensor变为[2,1280,1,1],分辨率降低,通道数变多,我们的目的是将tensor拉平,即只要batch_size和channels,方便后续分类。此时可以采用view()操作,这也是最常见的操作。

'''
view()是根据元素总数来改变tensor形状的,即变形后的tensor元素总数不变
本例中元素总数为2*1280*1*1
x.size[0]是x的第一个维度batch_size,本例中为2,-1代表自动计算该维度
想要去掉H,W则只需指定第一个维度的batch_size自动计算第二个维度即可,因为H,W经过卷积后均为1
'''

x = x.view(x.size[0],-1)

2.flatten():将指定维度合并为一个维度

forward中可以修改如下:

 def forward(self, x):
        #2,3,32,32
        x = self.conv1(x)
        #2,3,32,32
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        x = self.bottleneck4(x)
        x = self.bottleneck5(x)
        x = self.bottleneck6(x)
        x = self.bottleneck7(x)
        #2,320,4,4
        x = self.conv2(x)
        #2,1280,4,4
        x = self.avgpool(x)
        #2,1280,1,1
        #tensor拉平发生的位置
        #flatten的两种方式
        #将第一维之后的维度合并
        #x = torch.flatten(x,1)
        #x = x.flatten(1)
        #2,1280
        x = self.linear(x)
        #2,1000
        return x

flatten原型如下:

flatten(input,start_dim=0,end_dim=-1)

其中input为输入的tensor,start_dim为起始维度,end_dim为终止维度。flatten的功能为将start_dim到end_dim的维度合并为一个维度。本例中将128011合并为一个维度1280

3.squeeze():去掉维度数为1的维度

由于本例的特殊性,最后的H,W均为1,则可直接用squeeze()去掉维度为1的维度

 def forward(self, x):
        #2,3,32,32
        x = self.conv1(x)
        #2,3,32,32
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        x = self.bottleneck4(x)
        x = self.bottleneck5(x)
        x = self.bottleneck6(x)
        x = self.bottleneck7(x)
        #2,320,4,4
        x = self.conv2(x)
        #2,1280,4,4
        x = self.avgpool(x)
        #2,1280,1,1
        #tensor拉平发生的位置
        x = x.squeeze()
        #2,1280
        x = self.linear(x)
        #2,1000
        return x

未完待续……

没有硬件条件,需要云服务的同学可以扫码看看:
请添加图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

十小大

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值