Pytorch迁移学习

什么是迁移学习?

本文对于迁移学习的概念仅做非常不严谨但十分通俗易懂的简介:
在机器学习中迁移学习即是拿别人的网络或网络参数,加以修改训练成为自己的网络。

迁移学习的两种常见策略

一:加载整个网络结构,但不加载任何参数。

from torchvision import models
Net = models.resnet50(pretrained=Flase)

这种方法比较少用,因为我们通常是不具备那么多的数据集来作训练自己的网络参数的。这里以resnet50为例子,在官网中还可以找到更多的经典网络。

https://pytorch.org/vision/0.8/models.html

二:加载整个网络结构以及预训练模型(官方参数)

from torchvision import models
Net = models.resnet50(pretrained=True)

这是比较常用的,但实际操作中我们经常需要修改一些全连接层之类的。

提取特定的网络层

这里用到了.children(),比如我们需要提取除了最后一层之外的所有层数,然后添加为十分类的全连接层。

import torch.nn as nn
from torchvision import models

Net = models.resnet50(pretrained=True)
news_resnet = nn.Sequential(*list(Net.children())[:-1])
fc = nn.Linear(2048, 10)

比较尴尬的是这种情况需要预先知道它原来的全连接层输出是多少,所以网上更倾向于使用这种方法。

import torch.nn as nn
from torchvision import models

Net = models.resnet50(pretrained=True)
class = Net.fc.in_features
Net.fc = nn.Linear(class, 10)

但我个人对于Net.fc.in_features,是看了好久才知道啥意思,所以非常不喜欢这种灵活性比较低的东西,就推荐下面这个方式。

import torch.nn as nn
from torchvision import models

Net = models.resnet50(pretrained=True)
for layer in Net.children():
	print(layer) #看一下resnet50里面到底长什么样,然后再修改对应的网络层。

冻结特定的网络结构

很多时候,我们并不需要对整个网络参数进行调整,因为官方已经训练的差不多了,我们只要微调最后几个层,适应我们的任务便足矣。

import torch.nn as nn
from torchvision import models

Net = models.resnet50(pretrained=True)
news_resnet = nn.Sequential(*list(Net.children())[:-1])

for p in news_resnet.parameters():
	p.requries_grad = Flase   #冻结news_resnet的所有参数,注意,最后一层在该语句的后面,所以没有被冻结。
	
fc = nn.Linear(2048, 10)

讲到这里主要的内容也就讲完了,接下来会附加全部代码以及一些相关知识博客。

import torch.nn as nn
from torchvision import models

class test_net(nn.Moudle):
    def __init__(self):
		super(Net, self).__init__()
		Net = models.resnet50(pretrained=True)
		self.news_resnet = nn.Sequential(*list(Net.children())[:-1])

		for p in self.parameters():
			p.requries_grad = Flase   #冻结news_resnet的所有参数,注意,最后一层在该语句的后面,所以没有被冻结。
			
		self.fc = nn.Linear(2048, 10)
	
	def forward(image):
		feat_one = self.news_resnet(image)
		all = feat_one.view(feat_one.shape[0], -1)
		out = self.fc(all)
		
		return out

自己搭建的网络的迁移学习

这是一个比较玄虚的问题了,例如我现在按照官方的表格搭建了resnet18。(预设我已经搭建好了)

class Net(nn.Module):
	def __init__(self):
		super(Net, self).__init__()
		self.net = resnet18()  #这里的resnet18是我预先搭建好的一个类
		self.net.load_state_dict(pretrain_path,Flase)#与训练模型的权重
		self.resnet = nn.Sequential(*list(self.net.children())[:-1])#剔除最后一层
		for p in self.resnet.parameters():
			p.requires_grad = Flase
		self.fc = nn.Linear(512,10)#换成十分类
	def forward(self, img):
		feat = self.resnet(img)
		out = self.fc(feat)
		return out

这里需要特别说明的net.load_state_dict(state_dict,strict)有两个参数,第二个参数strict意为是否要符合标准代码。(官方给出的代码,改代码用以训练我们的权重文件),但我们平时手打的代码,虽然实现一样,但多少有点命名之类的不一样,如果没有把strict设置为F很容易出现报错。最后一点需要注意的是,自己手打的代码即便加上官方的权重文件,效果也并不是十分理想,至于为什么本人也在探索之中。

parameters和state_dict的区别

看其他博主吧

https://blog.csdn.net/qq_33590958/article/details/103544175?ops_request_misc=%25257B%252522request%25255Fid%252522%25253A%252522161354954816780274193866%252522%25252C%252522scm%252522%25253A%25252220140713.130102334.pc%25255Fall.%252522%25257D&request_id=161354954816780274193866&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2allfirst_rank_v2~rank_v29-1-103544175.first_rank_v2_pc_rank_v29&utm_term=parameter%25E5%2592%258Cstate_dict%2528%2529%25E7%259A%2584%25E5%258C%25BA%25E5%2588%25AB

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

__TAT__

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值