全网都在讲迁移学习,可你会写代码了吗?收藏我这个,10分钟开始你的迁移学习训练

前言

先,这里不讲迁移学习的理论,只讲实践,因为理论已经全网飞了~~,不懂得大家先去学理论,理论学了再来实操。
今天,在这里只想给大家介绍一种代码写法,适用于基于pytorch的迁移学习。

迁移学习主要用在分类模型上,把原本在ImageNet或其他数据集上训练好的模型,迁移到自己的项目上来。所以对于分类模型,我们要把模型最后一层(通常是全连接层)的输出分类类别数量改了。比如,原本在ImageNet上是分1000类,而我们的目标是分3类,就要把最后的类别数改为3。

整个过程分为两步:

第一步 加载预训练模型并修改类别数

from torchvision.models import densenet169, resnet50
import torch.nn as nn
import torch.optim as optim

### 加载模型, 并修改模型的最后一层  ####
model = densenet169(pretrained=True)
# 设定 pretrained = True 就会加载训练好的模型

# 修改模型的最后一层。不同的模型,最后一层的修改略有差异
arch = 'densenet169'
classes = 2  # 分类的数量
if 'resnet' in arch:
    # for param in model.layer4.parameters():
    model.fc = nn.Linear(2048, classes)

if 'dense' in arch:
    if '121' in arch:
        # (classifier): Linear(in_features=1024)
        model.classifier = nn.Linear(1024, classes)
    elif '169' in arch:
        # (classifier): Linear(in_features=1664)
        model.classifier = nn.Linear(1664, classes)

第二步 选择模型所有层/最后一层进行反向传播优化

迁移学习有两种模型:一种是对预训练好的模型,再次从头训练,模型的每一层都要重新优化。另一种是只重新训练模型最后一层,其余层的参数固定。第一种方法适用于大多数迁移学习,预训练好的模型是在自然图像上训练的,如果迁移到医学图像上来,那么特征之间的差异很大,这时就选择第一种,重头训练。但假如有一个任务是分别猫和狗这种自然图像,且在ImageNet这个数据集中已经包含的,那么就只优化最后一层参数即可。

####  选择模型所有层都要进行反向传播优化 还是 只优化最优一层  #####
fullretrain = True  # True: 表示所有层都要进行优化,为False: 只优化最后一层

if fullretrain:
    print("=> optimizing all layers")
    for param in model.parameters():
        param.requires_grad = True
    optimizer = optim.Adam(model.parameters(), lr=0.03, weight_decay=1e-4)
        # model.parameters(): 把模型所有参数都传进去
else:
    print("=> optimizing fc/classifier layers")
    optimizer = optim.Adam(model.module.fc.parameters(), lr=0.03, weight_decay=1e-4)
    # model.module.fc.parameters(): 只传最后一个分类层的参数进去
    # 注意: 不同模型,最后一层的名字不一样

我们从代码里面可以发现,两种方法的区别就是优化器(optimizer)接收的参数不一样,第一种方法是把模型的所有参数都传进去,第二种是只传模型最后一层的参数。

迁移学习这部分的代码就讲完了。其余的就跟平时训练模型一样的写法。
如果是做分类实验,经验来看,采用预训练的模型都比你自己从0开始训练的效果好。不信的话,可以自己对比对比。

接下来,对其中的部分细节进行进一步的探讨~~

探讨:如何确定模型最后一层的名字是什么

如上述代码里,在resnet这个模型中,它最后一层叫: fc
在Densenet模型中,最后一层叫: classifier
那我怎么知道它最后一层叫什么呢?

方法一: 查询源代码

最直接的办法就是进源代码里面去查看。
在这里插入图片描述
在这里插入图片描述

方法二: 查询模型的子模块名字

model = densenet169(pretrained=False)
for name in model.named_modules():
    print(name)

在这里插入图片描述
这种方法不仅可以知道模型最后一层的名字。
另外,nn.Linear(1664, classes),这括号里的1664是怎么知道的,这种方法还可以查到模型最后一层的输入节点是1664。

觉得有用点赞,关注,你的鼓励才是我坚持更新的动力♥️♥️♥️♥️

在这里插入图片描述

  • 81
    点赞
  • 272
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Tina姐

我就看看有没有会打赏我

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值