pytorch调用预训练模型

最近刚开始入手pytorch,搭网络要比tensorflow更容易,有很多预训练好的模型,直接调用即可。
参考链接

import torch
import torchvision.models as models #预训练模型都在这里面
#调用alexnet模型,pretrained=True表示读取网络结构和预训练模型,False表示只加载网络结构,不需要预训练模型
alexnet = models.alexnet(pretrained=False)  #只加载结构

print(alexnet)  # 打印模型结构
# 下面是两种加载模型参数的方式:
model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) #下载网上的预训练模型
alexnet.load_state_dict(torch.load('F:/DeepLearning/alexnet-owt-4df8aa71.pth'))#加载预先下载好的预训练参数到alexnet

print(alexnet)  # 打印的还是模型结构
pre_dict = alexnet.state_dict()  # 按键值对将模型参数加载到pre_dict
print((k,v) for k ,v in pre_dict.items())  # 打印模型参数
for k ,v in pre_dict.items():  #打印模型每层命名
    print(k)

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
#note:model是自己定义好的模型,将pretrained_dict和model_dict中命名一致的层加入pretrained_dict(包括参数)

VGG模型同理:

vgg16 = models.vgg16(pretrained=True) #加载网络结构和预训练模型
#static_dict()返回包含模块所有状态的字典
pretrained_dict = vgg16.state_dict()  #返回内置预训练vgg模块的字典
model_dict = model.state_dict()  #返回我们自己model的字典

#------------------------最关键的三步------------------------------------------
# 1. filter out unnecessary keys,也就是说从内置模块中删除掉我们不需要的字典
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# 2. overwrite entries in the existing state dict,利用pretrained_dict更新现有的model_dict
model_dict.update(pretrained_dict)

# 3. load the new state dict,更新模型,加载我们真正需要的state_dict
model.load_state_dict(model_dict)

保存和加载模型可以参考链接

还有一个很好的博客 https://blog.csdn.net/u014380165/article/details/79119664

 
------2019.05.07更新----------
加载inception-v3模型并重写fc层:

import torch
import torchvision.models as models
class Model(nn.Module):
	def __init__(self, args):
		super(Model, self).__init__()
		self.args = args		
		if self.args.backbone == 'inception_v3':
			self.model = models.inception_v3()
			if self.args.pretrained:
				self.model.load_state_dict(
					torch.load(os.path.join(self.args.pretrained_models_dir + '.pth')))

			# del self.model._modules['AuxLogits'] #删除AuxLogits模块
			self.model.AuxLogits.fc = nn.Linear(self.model.AuxLogits.fc.in_features, self.args.n_classes) #将模型AuxLogits模块的fc输出通道数改成我们需要的分类数
			print(self.model) #打印模型结构
			print(self.model._modules.keys())  #可以打印出模型的所有模块名称
			self.features = nn.Sequential(*list(self.model.children())[:-1], ) #去掉最后一层fc层,这句也可以写成# del self.model._modules['fc']
			self.last_node_num = 2048
			
		self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))  #全局池化
		self.classifier = nn.Linear(self.last_node_num, self.args.n_classes)  #最后加了一个全连接层
		
	def forward(self, x):  #重写forward函数,把几个模块组合起来
		x = self.features(x)
		x = self.avg_pool(x)
		x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
		x = self.classifier(x)
		return x

2019.6.18更新
PyTorch学习:加载模型和参数 https://blog.csdn.net/lscelory/article/details/81482586

Pytorch保存和加载参数有两种方式:(参考

  1. 只保存参数,不保存模型,对应的保存和加载方法分别是:
## 保存训练完的网络的各层参数,保存训练完的网络的各层参数(即weights和bias)
torch.save(net.state_dict(),path)
##加载保存到path中的各层参数到神经网络net2,不可以直接为torch.load_state_dict(path),此函数不能直接接收字符串类型参数
net2.load_state_dict(torch.load(path))
  1. 保存模型和参数,对应的保存和加载方法分别是:
## 保存训练完的整个网络模型(不止weights和bias,还有模型本身)
torch.save(net,path)
## 加载保存到path中的整个神经网络
net2=torch.load(path)

官方推荐方式一,原因自然是保存的内容少,速度会更快。

  • 6
    点赞
  • 67
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值