resnet 论文_ResNet 残差网络论文阅读及示例代码

论文阅读

其实论文的思想在今天看来是不难的,不过在当时 ResNet 提出的时候可是横扫了各大分类任务,这个网络解决了随着网络的加深,分类的准确率不升反降的问题。通过一个名叫“残差”的网络结构(如下图所示),使作者可以只通过简单的网络深度堆叠便可达到提升准确率的目的。

5fa272c69e976d18637b2d5a437a7d4e.png

残差结构

残差结构的处理过程分成两个部分,左边的 F(X) 与右边的 X,最后结果为两者相加。其中右边那根线不会对 X 做任何处理,所以没有可学习的参数;左边部分 F(X) 为网络中负责学习特征的部分,把整个残差结构看做是 H(X) 函数的话,则负责学习的部分可以表示为 H(X)=F(X)-X,这个结构学习的其实是输出结果与输入的差值,这也是残差名字的由来。完整的 ResNet 网络由多个上图中所示的残差结构组成,每个结构学习的都是输出与输入之间的差值,通过步步逼近,达到了比直接学习输入好得多的效果。

文中残差结构的具体实现分为两种,首先介绍 ResNet-18 与 ResNet-34 使用的残差结构称为 Basic Block,如下图所示,图中的结构包含了两个卷积操作用于提取特征。

fcbab7ca9e9abe866b255ae592f91571.png

Basic Block

对应到代码中,这是 Pytorch 自带的 ResNet 实现中的一部分,跟上图对应起来看更加好理解,我个人比较喜欢论文与代码结合起来看,因为我除了需要知道原理之外,也要知道如何去使用,而代码更给我一种一目了然的感觉:

class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out

另一种残差结构称为 Bottleneck,就是瓶颈的意思:

e6e01b6edbc6a3e73ffa22363bcaf589.png

作者起名字真的很形象,网络结构也正如这瓶颈一样,首先做一个降维,然后做卷积,然后升维,这样做的好处是可以大大减少计算量,专门用于网络层数较深的的网络,ResNet-50 以上的网络都有这种基础结构构成(不同层级的输入输出维度可能会不一样,但结构类似):

4a8c8ecaec4e3823669bc1bb5d10e7dc.png

Pytorch 中的代码,注意到上图中为了减少计算量,作者将 256 维的输入缩小了 4 倍变为 64 进入卷积,在升维时需要升到 256 维,对应代码中的 expansion 参数:

class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out

由上面介绍的基本结构再加上池化以及全连接层,就构成了各种完整的网络:

001d8512f741a9dbd7b29cb7091c1a3c.png

图中的网络在 Pytorch 中都已经集成进去了,而且都是预训练好的,我们可以在预训练好的模型上面训练自己的分类器,大大减少我们的训练时间。下面介绍一下如何使用 ResNet。

在 Pytorch 中使用 ResNet

Pytorch 是一个对初学者很友好的深度学习框架,入门的话非常推荐,官方提供了一小时入门教程:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html

在 Pytorch 中使用 ResNet 只需要 4 行代码:

from torch import nn# torchvision 专用于视觉方面import torchvision  # pretrained :使用在 ImageNet 数据集上预训练的模型model = torchvision.models.resnet18(pretrained=True)# 修改模型的全连接层使其输出为你需要类型数,这里是10# 由于使用了预训练的模型 而预训练的模型输出为1000类,所以要修改全连接层# 若不使用预训练的模型可以直接在创建模型时添加参数 num_classes=10 而不需要修改全连接层model.fc = nn.Linear(model.fc.in_features, 10)
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值