目录
任务
任务简介
了解transfer learning 与 model finetune
详细介绍
学习模型微调(Finetune)的方法,以及认识Transfer Learning(迁移学习)与Model Finetune之间的关系。
知识点
Transfer Learning & Model Finetune
迁移学习:机器学习分支,研究源域(source domain)的知识如何应用到目标域(target domain)。
模型微调:所谓的模型微调,其实就是模型的迁移学习,在深度学习中,通过不断的迭代,更新卷基层中的权值,这里的权值可以称之为 knowledge , 然后我们可以将这些 knowledge 进行迁移
,主要目的是将这些 knowledge 运用到新的模型
中,这样既可以减小由于数据量不足导致的过拟合现象
,同时又能加快模型的训练速度
具体说来,对于卷积神经网络,我们可以把前面的卷基层,池化层看作是 feature extactor(特征提取)
,是一个非常有共性的部分。得到一系列的feature map。
而后面的全连接层,可以称之为 classifier (分类器)
, 与具体的任务有关。这一部分就需要针对不同的训练任务进行调整,尤其是最后一层
需要根据任务进行相应的调整。
PyTorch中的Finetune
模型微调步骤
- 获取预训练模型参数—原任务中获取得到的知识
- 加载模型(load_state_dict)
- 修改输出层
模型微调训练方法
- 固定预训练的参数
(requires_grad =False;lr=0)
- Features Extractor较小学习率(params_group)
实战-Resnet-18 用于二分类
Resnet-18 模型介绍
蚂蚁蜜蜂二分类数据
训练集:各120~张
验证集:各70~张
Resnet-18模型结构如下图所示:
前面四层是特征提取,接下来四层(layer1~layer4)是残差网络,然后接avgpool池化层,最后接FC分类(原模型是1000分类,ImageNet上训练的)。
迁移结果分析
(1) 直接训练
如果不采用Resnet-18模型进行Finetune,直接对二分类数据进行训练,得到的Loss曲线
损失值一直在0.6附近,并且得到的Accuracy只有70%
(2) 迁移训练,但不冻结卷积层,固定学习率
# ============================ step 2/5 模型 ============================
# 1/3 构建模型
resnet18_ft = models.resnet18()
# 2/3 加载参数
# flag = 0
flag = 1
if flag:
path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
state_dict_load = torch.load(path_pretrained_model)
resnet18_ft.load_state_dict(state_dict_load)
# 3/3 替换fc层
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)
resnet18_ft.to(device)
可以看出,损失值最后收敛到在0.2附近,并且在第二个Epoch的Accuracy就达到了90%。
(3)迁移训练,冻结卷基层,固定学习率
从代码看,所谓冻结卷积层,是直接把参数的梯度设置为False
for param in resnet18_ft.parameters():
param.requires_grad = False
# ============================ step 2/5 模型 ============================
# 1/3 构建模型
resnet18_ft = models.resnet18()
# 2/3 加载参数
# flag = 0
flag = 1
if flag:
path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
state_dict_load = torch.load(path_pretrained_model)
resnet18_ft.load_state_dict(state_dict_load)
# 法1 : 冻结卷积层
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
for param in resnet18_ft.parameters():
param.requires_