torchvision加载ResNet除全连接层的权重

前言

 本文主要提供了几个常用处理Module的写法:
1)加载torchvision中除全连接层的权重;
 2)删除torchvision中ResNet的全连接层并添加新的全连接层;
 3)冻结ResNet的layer2前模型的参数。

 当然,若想理解背后深层原理,欢迎阅读:

nn.Module源码介绍
nn.Module冻结参数

1、torchvision加载ResNet除全连接层的权重

import torch
import torch.nn as nn
import torchvision
#
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        # 这里定义自己的ResNet网络
    # 往ResNet里面添加权重
    def init_weights(self, pretrained = True):
        """
        Args:
            self: 模型本身
            pretrained (bool)
        """
        if pretrained == True:
            # 获取ResNet34的预训练权重
            resnet34 = torchvision.models.resnet34(pretrained=True)
            pretrained_dict = resnet34.state_dict()
            """加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数
               也可以直接从官方model_zoo下载:
               pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
            # 获取当前模型的参数字典
            model_dict = self.state_dict()
            # 将pretrained_dict里不属于model_dict的键剔除掉
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            # 更新现有的model_dict
            model_dict.update(pretrained_dict)
            # 加载我们真正需要的state_dict
            self.load_state_dict(model_dict)
            print('成功加载预训练权重')
        else:
            pass
    
if __name__ == '__main__':
    resnet = ResNet()
    resnet.init_weights(pretrained=True)

 其实就是对齐权重字典之间的key即可。

2、torchvision中删除并添加ResNet的全连接层

#导入原始的训练好的ResNet34
resnet34 = torchvision.models.resnet34(pretrained=True)
#删除fc层
del resnet34.fc
#换一个新的全连接层
resnet34.add_module('fc',nn.Linear(2048,2))

  当然自己定义的ResNet类可以随意更改啦。

3、冻结torchvision中ResNet前layer2的参数

#导入原始的训练好的ResNet34
resnet34 = torchvision.models.resnet34(pretrained=True)
#若想冻结所有参数
# for params in resnet34.parameters():
#     resnet34.eval() # 由于有BN层,eval使得BN使用全局均值和方差
#     params.requires_grad = False
# 冻结包括layer1之前的所有模块
for name,module in resnet34.named_children():
    # 若没到layer2就一直冻结
    if name != 'layer2' :
        for p in module.parameters():
            p.requires_grad = False
        ==# 遍历当前module下所有子module,将BN进入eval状态。==
        for submodule in module.modules():
            if isinstance(submodule, nn.BatchNorm2d):
                submodule.eval()
    else:
        break
#可以查看下各个参数梯度是否变成FALSE了
for name,params in resnet34.named_parameters():
    print(name,':',params.requires_grad)

 这里冻结了所有可学习参数,但由于有BN层,最好将其设置为eval状态,使用训练集的全局均值和方差。当然,若想训练该resnet34,需要重写train方法。这里详细内容可参考:
MMdet解读ResNet

总结

 本文主要介绍了一些易忘且常用的操作,后续会逐渐进行补充。若有问题欢迎+vx:wulele2541612007,拉你进群探讨交流。

  • 7
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值