Pytorch 快速入门(七)加载预训练模型初始化网络参数

在预训练网络的基础上,修改部分层得到自己的网络,通常我们需要解决的问题包括: 
1. 从预训练的模型加载参数 

2. 对新网络两部分设置不同的学习率,主要训练自己添加的层 

PyTorch提供的预训练模型

PyTorch定义了几个常用模型,并且提供了预训练版本:

  • AlexNet: AlexNet variant from the “One weird trick” paper.
  • VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
  • ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
  • SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1

预训练模型可以通过设置pretrained=True来构建:

eg:

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)

预训练模型期望的输入是RGB图像的mini-batch:(batch_size, 3, H, W),并且H和W不能低于224。图像的像素值必须在范围[0,1]间,并且用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]进行归一化。

加载预训练模型

加载参数可以参考 apaszke推荐的做法 ,即删除与当前model不匹配的key。

torch.nn.Module对象有函数static_dict()用于返回包含模块所有状态的字典,包括参数和缓存。键是参数名称或者缓存名称。

函数Module::load_state_dict(state_dict)用state_dict中的状态值更新模块的状态值。static_dict中的键应该和函数static_dict()返回的字典中的键完全一样。

下面给出加载预训练的模型的示例:

vgg16 = models.vgg16(pretrained=True)
pretrained_dict = vgg16.state_dict()
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(model_dict)

不同层设置不同学习率的方法 

此部分主要参考 PyTorch教程的Autograd machnics部分 
在PyTorch中,每个Variable数据含有两个flag(requires_grad和volatile)用于指示是否计算此Variable的梯度。设置requires_grad = False,或者设置volatile=True,即可指示不计算此Variable的梯度
for param in model.parameters():
    param.requires_grad = False
注意,在模型测试时,对input_data设置volatile=True,可以节省测试时的显存




  • 3
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch是一个功能强大的机器学习框架。它使用动态计算图和高效的自动微分来加速深度学习。在实际编码的过程中,我们经常会使用预训练模型来加速模型训练和进一步提升模型准确率,不过一些时候我们并不需要整个预训练模型的所有参数来进行训练,而是只需要加载预训练模型的部分参数。那么在PyTorch中,我们要如何来加载预训练模型的部分参数呢? 要想加载预训练模型的部分参数,在PyTorch中,我们可以使用load_state_dict()函数实现。load_state_dict()函数在PyTorch中是将参数拷贝到新模型中的函数,新模型和预训练模型网络结构应该是相同的。然后我们可以通过load_state_dict()函数的参数prefix和exclude来实现部分参数加载。prefix参数是指定了预训练模型中需要加载参数的前缀,而exclude参数是指定了我们不需要加载参数。 例如,我们有一个预训练模型‘resnet18.pth’,它包含了resnet18模型在imagenet上训练好的模型参数。我们想要使用这个模型来进行一些迁移学习,那只需要加载resnet18最后一层fc层之前的所有模型参数,而不需要加载最后一层fc层的权重。那么,我们可以通过以下代码来实现: ``` import torch.utils.model_zoo as model_zoo import torchvision.models as models # 定义一个resnet18模型 resnet18 = models.resnet18(pretrained=False) # 加载预训练模型的所有参数 model_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' resnet18.load_state_dict(model_zoo.load_url(model_url)) # 获取所有要加载参数的名字 params_to_update = [] for name, param in resnet18.named_parameters(): if 'fc' not in name: params_to_update.append(name) # 加载部分预训练模型参数 state_dict = model_zoo.load_url(model_url) model_dict = resnet18.state_dict() for name, value in state_dict.items(): if name.startswith(tuple(params_to_update)): model_dict.update({name: value}) resnet18.load_state_dict(model_dict) ``` 上述代码先是定义了一个resnet18模型,然后加载resnet18预训练模型的所有参数。通过获取所有需要加载参数的名字,然后将其加载到新模型中,从而实现了加载预训练模型的部分参数的目的。 总结: 通过使用load_state_dict()函数的prefix和exclude参数,在PyTorch中实现了对预训练模型的部分参数加载。这将使我们在使用预训练模型时更加灵活和高效。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值