什么是迁移学习?
本文对于迁移学习的概念仅做非常不严谨但十分通俗易懂的简介:
在机器学习中迁移学习即是拿别人的网络或网络参数,加以修改训练成为自己的网络。
迁移学习的两种常见策略
一:加载整个网络结构,但不加载任何参数。
from torchvision import models
Net = models.resnet50(pretrained=Flase)
这种方法比较少用,因为我们通常是不具备那么多的数据集来作训练自己的网络参数的。这里以resnet50为例子,在官网中还可以找到更多的经典网络。
https://pytorch.org/vision/0.8/models.html
二:加载整个网络结构以及预训练模型(官方参数)
from torchvision import models
Net = models.resnet50(pretrained=True)
这是比较常用的,但实际操作中我们经常需要修改一些全连接层之类的。
提取特定的网络层
这里用到了.children(),比如我们需要提取除了最后一层之外的所有层数,然后添加为十分类的全连接层。
import torch.nn as nn
from torchvision import models
Net = models.resnet50(pretrained=True)
news_resnet = nn.Sequential(*list(Net.children())[:-1])
fc = nn.Linear(2048, 10)
比较尴尬的是这种情况需要预先知道它原来的全连接层输出是多少,所以网上更倾向于使用这种方法。
import torch.nn as nn
from torchvision import models
Net = models.resnet50(pretrained=True)
class = Net.fc.in_features
Net.fc = nn.Linear(class, 10)
但我个人对于Net.fc.in_features,是看了好久才知道啥意思,所以非常不喜欢这种灵活性比较低的东西,就推荐下面这个方式。
import torch.nn as nn
from torchvision import models
Net = models.resnet50(pretrained=True)
for layer in Net.children():
print(layer) #看一下resnet50里面到底长什么样,然后再修改对应的网络层。
冻结特定的网络结构
很多时候,我们并不需要对整个网络参数进行调整,因为官方已经训练的差不多了,我们只要微调最后几个层,适应我们的任务便足矣。
import torch.nn as nn
from torchvision import models
Net = models.resnet50(pretrained=True)
news_resnet = nn.Sequential(*list(Net.children())[:-1])
for p in news_resnet.parameters():
p.requries_grad = Flase #冻结news_resnet的所有参数,注意,最后一层在该语句的后面,所以没有被冻结。
fc = nn.Linear(2048, 10)
讲到这里主要的内容也就讲完了,接下来会附加全部代码以及一些相关知识博客。
import torch.nn as nn
from torchvision import models
class test_net(nn.Moudle):
def __init__(self):
super(Net, self).__init__()
Net = models.resnet50(pretrained=True)
self.news_resnet = nn.Sequential(*list(Net.children())[:-1])
for p in self.parameters():
p.requries_grad = Flase #冻结news_resnet的所有参数,注意,最后一层在该语句的后面,所以没有被冻结。
self.fc = nn.Linear(2048, 10)
def forward(image):
feat_one = self.news_resnet(image)
all = feat_one.view(feat_one.shape[0], -1)
out = self.fc(all)
return out
自己搭建的网络的迁移学习
这是一个比较玄虚的问题了,例如我现在按照官方的表格搭建了resnet18。(预设我已经搭建好了)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = resnet18() #这里的resnet18是我预先搭建好的一个类
self.net.load_state_dict(pretrain_path,Flase)#与训练模型的权重
self.resnet = nn.Sequential(*list(self.net.children())[:-1])#剔除最后一层
for p in self.resnet.parameters():
p.requires_grad = Flase
self.fc = nn.Linear(512,10)#换成十分类
def forward(self, img):
feat = self.resnet(img)
out = self.fc(feat)
return out
这里需要特别说明的net.load_state_dict(state_dict,strict)有两个参数,第二个参数strict意为是否要符合标准代码。(官方给出的代码,改代码用以训练我们的权重文件),但我们平时手打的代码,虽然实现一样,但多少有点命名之类的不一样,如果没有把strict设置为F很容易出现报错。最后一点需要注意的是,自己手打的代码即便加上官方的权重文件,效果也并不是十分理想,至于为什么本人也在探索之中。
parameters和state_dict的区别
看其他博主吧
https://blog.csdn.net/qq_33590958/article/details/103544175?ops_request_misc=%25257B%252522request%25255Fid%252522%25253A%252522161354954816780274193866%252522%25252C%252522scm%252522%25253A%25252220140713.130102334.pc%25255Fall.%252522%25257D&request_id=161354954816780274193866&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2allfirst_rank_v2~rank_v29-1-103544175.first_rank_v2_pc_rank_v29&utm_term=parameter%25E5%2592%258Cstate_dict%2528%2529%25E7%259A%2584%25E5%258C%25BA%25E5%2588%25AB