目前Transformer在CV届已经大杀四方,各个赛事上都取得了SOAT的水平。很多人也想着将各种Transformer-based的backbone拿来用,但是很多backbone都是需要加载预训练参数才能使用,所以我们需要将网络进行迁移学习,才能拿来使用。我们通常会在网络刚开始训练的时候冻结除分类层之外的参数,在训练一到两轮再解封参数。下面我以swin-Transformer为例,介绍如何进行网络的fine-tune。
一、冻结网络参数
首先需要使用Timm库加载Swin的结构以及参数,Timm库是Ross Wightman大神在github上开源的用于图像分类的库,包含各种各样的CNN以及Transformer-based的backbone,以及预训练参数模型,使用起来非常友好,真顶!例如加载Swin只需要一句话就好了。
from timm.models import create_model
Swin=create_model('swin_large_patch4_window7_224_in22k',pretrained=True)
因为我的分类任务数为14和预训练参数中的不匹配,并且我想要在训练前期固定除了分类层之外的所有参数,所以我加载网络后,会先去掉分类层,然后固定这部分的参数,接着再重新构建分类层。固定参数的代码如下:
for p in self.backbone.parameters():
p.requires_grad = False
整体代码如下:
class classifer(nn.Module):
def __init__(self,in_ch,num_classes):
super().__init__()
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.fc = nn.Linear(in_ch,num_classes)
def forward(self, x):
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
x = self.fc(x)
return x
class Swin(nn.Module):
def __init__(self):
super().__init__()
#创建模型,并且加载预训练参数
self.swin= create_model('swin_large_patch4_window7_224_in22k',pretrained=True)
#整体模型的结构
pretrained_dict = self.swin.state_dict()
#去除模型的分类层
self.backbone = nn.Sequential(*list(self.swin.children())[:-2])
#去除分类层的模型架构
model_dict = self.backbone.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
self.backbone.load_state_dict(model_dict)
#屏蔽除分类层所有的参数
for p in self.backbone.parameters():
p.requires_grad = False
#构建新的分类层
self.head = classifer(1536, 14)
def forward(self, x):
x = self.backbone(x)
x=self.head(x)
return x
除了在模型中屏蔽参数外,还要再优化器中进行屏蔽,需要使用filter进行过滤,就是只优化梯度为True的参数,即分类层
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, cnn.parameters()), lr=1e-4, betas=(0.9, 0.999),weight_decay=1e-6)
二、解冻网络参数
在固定参数训练一轮以后,再解冻backbone部分的参数
if epoch ==1:
for p in Swin.backbone.parameters():
p.requires_grad = True
optimizer.add_param_group({'params': Swin.backbone.parameters()})