pytorch加载预训练模型

pytorch加载预训练的model

博主也是刚刚开始学习使用深度分类网络实现一些应用、做对比试验,pytorch是大家极力推荐的deep learning框架,与python本身的语言风格很切合,又有一种写matlab的感觉,受到广泛青睐。博主在一篇文章的对比试验中想要与VGGResNet以及XceptionNet在我的分类任务中作对比,之前也是对什么学习框架的应用一窍不通,在查阅了一些材料之后自己总结了一下,分享一下。

我们在训练某个网络时(尤其是一些声名远扬的网络),几乎都会使用预训练模型,如果模型的训练起点比较好,在合适的学习率下会让loss迅速收敛,加快学习的进程,也叫作fine-tuning。

加载预训练模型的本质问题就在于想要使用的参数的模型结构与当前的模型结构不完全一样,最简单的就是最后的全连接层的神经元数量不一样。本文针对的就是这种情况,最直接的应用场景就是例如:想借用1000分类的VGG-19的网络参数作为初始参数做fine-tuning,但是我的分类任务只是使用VGG-19实现简单的二分类,这时候就需要做一些操作来把全连接层前面的参数加载到我的网络中。

VGG-19

model = vgg19(num_classes = 2)                                                   # 初始化一个VGG-19的网络,网络结构函数已经import
if use_gpu:                                                                      # 使用GPU
     model = model.cuda()
     model = torch.nn.DataParallel(model)

pretrained_dict = torch.load("lujing/vgg19-dcbb9e9d.pth")                        # 加载预训练的模型

pretrained_dict = torch.load(args.resume)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}  # 剔除不匹配的参数
model_dict.update(pretrained_dict)                                               # 模型参数更新
model.load_state_dict(model_dict)

ResNet-50

model = resnet50(num_classes = 2)                                                # 初始化一个ResNet-50的网络,网络结构函数预先定义已经import
if use_gpu:                                                                      # 使用GPU
     model = model.cuda()
     model = torch.nn.DataParallel(model)

pretrained_dict = torch.load("lujing/resnet50-6-classes.pth")                    # 加载预训练的模型,原来的模型是6分类

pretrained_dict = torch.load(args.resume)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}  # 剔除不匹配的参数
del pretrained_dict["module.fc.weight"]                                          # 删除掉全连接层的权重和偏置参数
del pretrained_dict["module.fc.bias"]
model_dict.update(pretrained_dict)                                               # 模型参数更新
model.load_state_dict(model_dict)

XceptionNet

model = xception(num_classes = 2)                                                # 初始化一个ResNet-50的网络,网络结构函数预先定义并已经import
if use_gpu:                                                                      # 使用GPU
     model = model.cuda()
     model = torch.nn.DataParallel(model)

pretrained_dict = torch.load("lujing/xception-6-classes.pth")                    # 加载预训练的模型,原来的模型是6分类

pretrained_dict = torch.load(args.resume)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}  # 剔除不匹配的参数
del pretrained_dict["module.fc.weight"]                                          # 删除掉全连接层的权重和偏置参数
del pretrained_dict["module.fc.bias"]
model_dict.update(pretrained_dict)                                               # 模型参数更新
model.load_state_dict(model_dict)

总结

VGG加载时即使待加载的模型与当前的模型参数结构不符,使用

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 

就可以实现加载,不需要再像ResNet和Xception一样再加一句

del pretrained_dict["module.fc.weight"]
del pretrained_dict["module.fc.bias"]

具体原因我也没有摸清楚,也想请教大神指教,不过这样的方式加载是亲测可用的。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值