PyTorch学习笔记(1)-finetune网络的一些注意事项

1.Transfer Learning 

transfer learning 是十分重要的技术,尤其是在实际应用中,往往数据很大但是有标签的(即训练样本)样本是十分稀少的,我们对数据进行标注十分耗时耗力的事情。这个时候就需要使用迁移学习,通过不同的任务对网络进行finetune。

transfer learning 有以下几种场景:

1. 将ConvNet 作为特征提取器:将预训练好的网络去掉最后一层,前面的层作为特征提取器,在新的数据集上训练一个线性分类器或者是SVM。

2.finetune网络:在预训练好的网络上继续进行训练,可以是所有层,也可以freeze几层。

3.使用别人训练好的权重:训练往往需要大量的时间,我们可以在网上下载别人训练好的权重文件进行使用。

transfer learning 推荐阅读材料:http://cs231n.github.io/transfer-learning/


几种实际情况分析:

 新数据集小新数据集大
与original数据集相似训练一个线性分类器on CNN codes(1)finetune网络(2)
与original数据集不同在网络更早的层训练一个SVM分类器

可以不使用transfer,

也可以作为参数的初始化

finefune整个网络

2.Transfer Learning in PyTorch

在Pytorch中进行transfer learning可以查看tutorial,里面有详细的过程。(有时间会补充对其的整理分析)

https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

3.some problems

这部分是我的一些笔记总结,如果有相同疑问的同学希望可以帮助你。

私以为requires_grad 参数对于finetune网络是十分重要,在PyTorch中我们只需要设置这一变量的状态设置成False就足够去freeze网络的某一部分。

当我们使用torchvision.models 中的模型时:

(注:torchvision里面包括了一些流行的数据集、常用的网络结构和一些常用的计算机视觉图像处理的方法)

方式1:

import torchvision
vgg_model = torchvision.models.vgg16(pretrained=True)

直接将参数设置为True,这样就会在网上下载提供的参数文件。

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" 

问题:往往会遇到下载速度慢,甚至下载失败的情况。

方式2:

在网上下载对应的vgg16.py的文件和预训练好的权重文件,加载进去。

vgg16.py

class Vgg16(torch.nn.Module):
    def __init__(self):
        super(Vgg16, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

    def forward(self, X):
        h = F.relu(self.conv1_1(X))
        h = F.relu(self.conv1_2(h))
        relu1_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv2_1(h))
        h = F.relu(self.conv2_2(h))
        relu2_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv3_1(h))
        h = F.relu(self.conv3_2(h))
        h = F.relu(self.conv3_3(h))
        relu3_3 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv4_1(h))
        h = F.relu(self.conv4_2(h))
        h = F.relu(self.conv4_3(h))
        relu4_3 = h

        return [relu1_2, relu2_2, relu3_3, relu4_3]

下载的预训练权重为vgg16.weight,则加载现有参数模型的代码如下:

from vgg16 import Vgg16
import torch
vgg_model = Vgg16()
vgg_model.load_state_dict(torch.load("vgg16.weight"))
for params in vgg_model.parameters():
    params.requires_grad = False

到此,我就想我可以不可以使用torchvision.models里面的模型,然后自己在网上下载权重文件,这时就出现了

KeyError: 'unexpected key "XXXX" in state_dict'

 xxxx处可以是很多值,根据你自己的程序来看。

这类错误主要是由于模型参数变量名和下载的权重文件中的不一致所导致的,因为两者都是dict 类型,key-value相互对应,当key不一致时,就会出现这种情况。

为此解决此类问题,就变成了解决两个dict 的key不一致问题,我们可以将下载的权重文件中的key进行适当的修改,(可以先查看一下,差异大小,如果很大,可能下载的文件就和网络结构不匹配,那就直接不要用了)。

方式3:

在网上下载对应的文件。

import torchvision
import torch.nn as nn
import torch.optim as optim
import torch
#state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
model = torchvision.models.resnet18() #加载一个在ImageNet上预训练的模型
model.load_state_dict(torch.load("resnet18-5c106cde.pth"))
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(512,10) #torch.nn 中的Linear layer  :in_fearture_num out_fearture_num bias(bool)
optimizer = optim.SGD(model.fc.parameters(),lr = 1e-2,momentum = 0.9)

4.Addition

保存模型:

import torch 
torch.save(model.state_dict(),path)

加载全模型:

上面已提到。

加载模型部分参数:

以上面vgg16.weight那个为例,加载前n层

import torch
trained_vgg_params = torch.load("vgg16.weight")
layer_name = [name for name in trained_vgg_params.keys()]
#print(layer_name)
n = 3
for i,param in enumerate(vgg_model.parameters()):
    #print(i,param)
    param.data = trained_vgg_params[layer_name[i]]
    param.requires_grad = False
    if i > n:
        break





                
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值