pytorch 使用预训练模型如resnet、vgg等并修改部分结构

pytorch 使用预训练模型并修改部分结构

在一些常见的如检测、分类等计算机视觉任务中,基于深度学习的方法取得了很好的结果,其中一些经典模型也往往成为相关任务及比赛的baseline。在pytorch的视觉库torchvision中,提供了models模块供我们直接调用这些经典网络,如VGG,Resnet等。使用中往往不能直接使用现成的模型,需要进行一些修改。实际上我们可以很方便的在pytorch中使用并修改模型。

1. 直接使用pytorch中的经典模型

直接通过models调用即可,如

from torchvision import models

res101 = models.resnet101(pretrained=True)
vgg19 = models.vgg19_bn(pretrained=True)

如下如所示,models模块的__init__.py 包含了一系列不同的网络结构
在这里插入图片描述
以及网络模型的不同层数的结构,如resnet50, resnet101, vgg16, vgg19等
我们只需查阅手册或源码寻找是否有这个网络模型,有的话直接拿来用即可。参数 pre_trained为True时表示模型参数是在ImageNet预训练过的,否则就是随机初始化的参数。

在首次使用时,pytorch会自动下载模型文件,保存在用户cache目录内
在这里插入图片描述
在参加一些图像检测、分类、分割比赛时,或者一些不需要大幅修改网络结构的场景,可以直接采用pytorch自带的网络结构,无需自行搭建。

2. 模型结构的修改

在分类问题上,模型的最后一层一般是一个全连接层,输出的神经元个数就是类别信息,最后输出结果是一个浮点向量,大小表示某一类别的可能性,数值越大说明越倾向于分为该类。
显然直接使用预训练的网络不加修改那么总类别数就是固定的,当我们使用的场景类别数不一致时,就要自行修改模型的最后一层。那么如何进行替换和修改呢?

我们知道,在自定义网络结构时,通常是:

class myModel(Module):
  def __init__(self):
    # 模型结构
    self.conv1 = xxxx
    self.fc1 = xxxxx
    self.m = nn.Sequential(a,b,c...)
  def forward(self,x):
  # 前向传播

这样的形式。换言之,模型的每一层都记录在了这个模型类的实例的成员变量里。因此只要我们知道要修改的那一层叫什么名字,就能够进行修改。

例如 对resnet,最后一层全连接层就叫fc,所以我们可以:

res101 = models.resnet101(pretrained=True)
numFit = res101.fc.in_features
res101.fc = nn.Linear(numFit, numClass)

res101.fc就是这个网络的最后一层全连接层,in_feature是输出神经元数量,我们将它修改为输入神经元不变(也不能变,不然就出错了)输出神经元为我们需要的类别数的全连接网络。

有时候分类任务不光要输出类别,也要输出置信度,通常置信度就是分类为这个类别的概率,既然是概率,就要满足 0 ≤ P i ≤ 1 , Σ i = 1 N P i = 1 0\le P_{i}\le 1, \Sigma_{i=1}^{N}P_{i} = 1 0Pi1,Σi=1NPi=1
由于全连接网络直接输出的结果往往不能称之为“置信度”(只有大小之分,不满足0-1之间,和也不是1),通常会在后面加一层softmax作为激活函数,这样输出结果就是一个概率值了:

res101.fc = nn.Sequential(nn.Linear(numFit, numClass), nn.Softmax(dim=1))

以上是最简单的模型某一层就是一个单独的成员变量的情况。

那如果模型把好多东西塞进了一个Sequential怎么办呢?

例如vgg:
我们在torchvision/moduls/vgg.py 中找到VGG类的定义:
在这里插入图片描述
显然分类相关的3层全连接、激活函数、dropout都在一个Sequential类、名字叫做classifier的成员变量里,这种情况,我们需要把整个classifier都复写吗?

答案是不需要的。

我们知道Sequential同样继承自nn.Module类,这个类有一个成员变量叫做_modules
在这里插入图片描述
这是一个有序字典,存放了模块名称 - 模块内容 的键值对
每次新添加一层,都会做一次

self._modules[name] = module

这个操作。

这个name这里很有意思,一般我们很少给每一层网络都起一个名字,那默认的名字实际上是该模块索引的字符串形式。比如上述的vgg的classifier,它的第一层全连接,名字叫做’1’,最后一层name是’6’。这个名字部分以后有时间专门讨论一下。

回到这个问题,Sequential继承自nn.Module,自然也有这个字典。
所以对于vgg,我们可以:

vgg19 = models.vgg19_bn(pretrained=True)
vgg19.classifier._modules['6'] = nn.Sequential(nn.Linear(4096, numClass), nn.Softmax(dim=1))

就可以将最后一层全连接层替换掉了。

中间其他层也可以用类似的方式替换。

小结

总结来说,pytorch提供的网络模型还是比较实用的,对于不需要大幅修改的网络结构只要直接调用再局部修改就可以,满足一了一些简单的深度学习需求场景,可以不需要自己重新写一遍网路结构了。

在修改方面,基于pytorch模型的定义方式,我们只要知道其模型结构,这一点可以直接查找pytorch这部分的源码,了解到成员变量的名字,如果是Sequential,可以再通过_modules这个字典查找,都将能够较容易的找到被修改的那一层,直接替换成我们需要的结构即可。当然,替换后与训练的权重就不见了,取而代之的是随机初始化权重。

  • 61
    点赞
  • 238
    收藏
    觉得还不错? 一键收藏
  • 13
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值