pytorch之模型的保存、加载与修改

#介绍内容如下:
# 1)加载现有的网络模型;
# 2)对网络模型进行修改;
# 3)保存网络模型以及下一次如何对保存的网络进行读取
import torch
import torchvision
from torch import nn
# 加载现有的网络模型
vgg16_1=torchvision.models.vgg16(pretrained=True)
vgg16_2=torchvision.models.vgg16(pretrained=False)
# print(vgg16_1)

# 对网络模型进行修改
#  修改1:在最后增加一个线性层
vgg16_1.add_module('Add_Linear',nn.Linear(1000,10))
# print(vgg16_1)

# 修改2:在classifier的最后增加一个线性层并且删除之前的
del vgg16_1.Add_Linear
vgg16_1.classifier.add_module('Add_Linear',nn.Linear(1000,10))
# print(vgg16_1)
del vgg16_1.classifier.Add_Linear

# 修改3:对某一层进行修改
#在不知道的情况下可以先获得其输入的维度
input_feature=vgg16_1.classifier[6].in_features
vgg16_1.classifier[6]=nn.Linear(input_feature,10)
# print(vgg16_1)

# 对网络中的模型进行保存
# 模型保存与读取方法1,模型的结构和参数都会被保存
torch.save(vgg16_1,'vgg16_1.pth')
model1=torch.load('vgg16_1.pth')
# print(model1)

# 保存方式2,只保存模型的参数
torch.save(vgg16_2.state_dict(),'vgg16_2.pth')
model2=torch.load('vgg16_2.pth')
# print(model2)

# 由输出可以看到,只保存了模型的参数,因此加载方式如下
model3=torchvision.models.vgg16(pretrained=False)
model3.load_state_dict(torch.load('vgg16_2.pth'))
# print(model3)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值