pytorch使用(五)使用pytorch进行微调(fine-tuning)

pytorch使用:目录


pytorch使用(五)使用pytorch进行微调(fine-tuning)

在使用pytorch的时候,发现使用预训练的模型进行微调的时候有比较难的两步,一是如何加载需要的两部分模型

1. 定义网络并且加载网络参数
  • 首先定义自己模型并且加载预训练网络的模型和参数,定义自己模型的时候把想要用的层名字设置为和预训练模型一样的
  • 加载预训练模型中的参数到自己的模型
# load the pre-trained network
model_zero = C3D()
model_zero.load_state_dict(torch.load(paraPath))

model = ROI_C3D(classes=para['nClass'])#ROI_C3D is my net
model_dict = model.state_dict()

model_zero = {k: v for k, v in model_zero.state_dict().items() if k in model_dict}
model_dict.update(model_zero)
model.load_state_dict(model_dict)
2. 设置学习率

通常预训练层的学习率会低一些. 在下面这个例子中,在定义网络的时候,相比原来的模型,将最后一个全连接的名字改为了classifier

#set optimization method
ignored_params = list(map(id, model.classifier.parameters())) #layer need to be trained
base_params = filter(lambda p: id(p) not in ignored_params,model.parameters())
optimizer = optim.SGD([
    {'params': base_params},
    {'params': model.classifier.parameters(), 'lr': para['lr']*0.1}], 0.001, momentum=0.9, weight_decay=1e-4)

这样预训练的模型学习率是0.0001,而最后一个全连接是0.001

阅读更多
版权声明:转载请附上链接 https://blog.csdn.net/GYGuo95/article/details/79945631
文章标签: pytorch
个人分类: pytorch 深度学习
上一篇pycharm:Updating Indices,无法进行调试和运行
下一篇windows下 命令行+winscp 实现与linux的远程文件传输
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

关闭
关闭