最近在学习迁移学习的基础部分——预训练微调,代码水平有限,也是借此机会边学边实践边总结,希望自己和大家都有所收获。
学习基础:吴恩达老师,小土堆,刘二大人的视频,本次学习参考了以下文章。
pytorch 加载使用预训练模型和 fine tune 模型微调(冻结一部分层)实战_预训练模型锁层-CSDN博客
Pytorch冻结部分层的参数 - 简书 (jianshu.com)
1.建立数据集及数据导入【可以学习b站小土堆的视频】
2.模型构建和修改
一个简单的模型构建,如:
class model_name(nn.Module):
def __init__(self):
super(model_name, self).__init__()
self.model = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self, x):
x = self.model(x)
return x
想要利用现有的模型,以及预训练好的参数,需要修改模型:
- 基于pytorch中的vgg16
- 我们的分类数 通常和参考的预训练模型的分类数 不同,所以将vgg16原来的分类层清空
- 在此基础上,定义我们自己的网络
class model1(nn.Module):
def __init__(self, num_classes=5):
super(model1, self).__init__()
net = models.vgg16(weights='IMAGENET1K_V1') # 加载vgg16的模型和参数
net.classifier = nn.Sequential() # 将vgg16原来的分类层(fc)置空
# 定义我们自己的网络
self.features = net
self.classifier = nn.Sequential( # 新增的网络结构,也就是在vgg16卷积层后面添加的结构
nn.Linear(512*7*7, 512),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(512, 128),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(128, num_classes),
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1) # 展平
x = self.classifier(x)
return x
3.对网络进行部分的冻结(不更新参数)和部分的更新
# 总共有 44 层网络,37层预加载VGG模型里面的,还有 7 层外面自己加的,我们把前面一些层预加载的模型冻结住,后面的一些层更新
para_optim = []
for i, single_layer in enumerate(model.modules()): # model.modules()能够迭代地遍历模型的所有子层
print(i, single_layer)
if i > 36: # 后面7层更新
for param in single_layer.parameters():
param.requires_grad = True # 更新
para_optim.append(param)
else: # 前面37层冻结
for param in single_layer.parameters():
param.requires_grad = False # 冻结
print(f'para_optim len = {len(para_optim)}')
简单来说,model.modules()可以遍历你的网络结构的每一层 ,这样才可以冻结每一层的参数
关于model.modules()的内容,可以参考以下文章:
4.对于优化器,进行一些设置!!
将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()
函数。
filter()
函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)
其他代码不变
进行 模型修改+参数冻结更新+优化器修改
就可以完成预训练微调了!
pytorch 加载使用预训练模型和 fine tune 模型微调(冻结一部分层)实战_预训练模型锁层-CSDN博客
我主要参考了以上文章进行学习哦!