PyTorch自学系列之 pretrained model usage

首先为了尊重他人贡献,本博客参考链接为:https://blog.csdn.net/VictoriaW/article/details/72821329 (感谢知识传递者)

楼主使用预训练模型的初衷是:目标识别检测框架在不同的计算机视觉应用中性能的提升很大一部分原因是由于骨干网络的西能提升,为了使用更好的Backbone获取feature map,所以有必要学习如何使用更好的预训练模型。

套路:在pytorch库中,使用预训练模型可以分为两种方式:第一,完全加载网络结构和checkpoint的权重;第二,首先构建相同结构的网络,然后加载预训练模型的checkpoint的权重state。一般而言,使用后者是一个通用操作。

pytorch中的torchvision.models模块中存放了较常用的backbone:vgg11,vgg11_bn,vgg13, vgg13_bn, vgg16,vgg16_bn,vgg19,vgg19_bn; resnet18, resnet34, resnet50, resnet101, resnet152; densnet121, densnet162, densnet169, densnet201等等(不一一列举)

第一步,构建想用使用的预训练网络的结构

可以自己构建,可以直接调用(即从pytorch.org...)中去下载

import torchvision.models as models    # 导入相应模块
       resnet18 = models.resnet18(pretrained=True)
       vgg16 = models.vgg16(pretrained=True)
      alexnet = models.alexnet(pretrained=True)
      squeezenet = models.squeezenet1_0(pretrained=True)

第二步,如果需要可以微调网络结构和筛选参数

       vgg16 = models.vgg16(pretrained=True)
       pretrained_dict = vgg16.state_dict()
       model_dict = model.state_dict()

      # 1. filter out unnecessary keys
      pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
      # 2. overwrite entries in the existing state dict
      model_dict.update(pretrained_dict) 
      # 3. load the new state dict
      model.load_state_dict(model_dict)
 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值