文章目录
前言
本文主要提供了几个常用处理Module的写法:
1)加载torchvision中除全连接层的权重;
2)删除torchvision中ResNet的全连接层并添加新的全连接层;
3)冻结ResNet的layer2前模型的参数。
当然,若想理解背后深层原理,欢迎阅读:
1、torchvision加载ResNet除全连接层的权重
import torch
import torch.nn as nn
import torchvision
#
class ResNet(nn.Module):
def __init__(self):
super(ResNet, self).__init__()
# 这里定义自己的ResNet网络
# 往ResNet里面添加权重
def init_weights(self, pretrained = True):
"""
Args:
self: 模型本身
pretrained (bool)
"""
if pretrained == True:
# 获取ResNet34的预训练权重
resnet34 = torchvision.models.resnet34(pretrained=True)
pretrained_dict = resnet34.state_dict()
"""加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数
也可以直接从官方model_zoo下载:
pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
# 获取当前模型的参数字典
model_dict = self.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
self.load_state_dict(model_dict)
print('成功加载预训练权重')
else:
pass
if __name__ == '__main__':
resnet = ResNet()
resnet.init_weights(pretrained=True)
其实就是对齐权重字典之间的key即可。
2、torchvision中删除并添加ResNet的全连接层
#导入原始的训练好的ResNet34
resnet34 = torchvision.models.resnet34(pretrained=True)
#删除fc层
del resnet34.fc
#换一个新的全连接层
resnet34.add_module('fc',nn.Linear(2048,2))
当然自己定义的ResNet类可以随意更改啦。
3、冻结torchvision中ResNet前layer2的参数
#导入原始的训练好的ResNet34
resnet34 = torchvision.models.resnet34(pretrained=True)
#若想冻结所有参数
# for params in resnet34.parameters():
# resnet34.eval() # 由于有BN层,eval使得BN使用全局均值和方差
# params.requires_grad = False
# 冻结包括layer1之前的所有模块
for name,module in resnet34.named_children():
# 若没到layer2就一直冻结
if name != 'layer2' :
for p in module.parameters():
p.requires_grad = False
==# 遍历当前module下所有子module,将BN进入eval状态。==
for submodule in module.modules():
if isinstance(submodule, nn.BatchNorm2d):
submodule.eval()
else:
break
#可以查看下各个参数梯度是否变成FALSE了
for name,params in resnet34.named_parameters():
print(name,':',params.requires_grad)
这里冻结了所有可学习参数,但由于有BN层,最好将其设置为eval状态,使用训练集的全局均值和方差。当然,若想训练该resnet34,需要重写train方法。这里详细内容可参考:
MMdet解读ResNet
总结
本文主要介绍了一些易忘且常用的操作,后续会逐渐进行补充。