#介绍内容如下:
# 1)加载现有的网络模型;
# 2)对网络模型进行修改;
# 3)保存网络模型以及下一次如何对保存的网络进行读取
import torch
import torchvision
from torch import nn
# 加载现有的网络模型
vgg16_1=torchvision.models.vgg16(pretrained=True)
vgg16_2=torchvision.models.vgg16(pretrained=False)
# print(vgg16_1)
# 对网络模型进行修改
# 修改1:在最后增加一个线性层
vgg16_1.add_module('Add_Linear',nn.Linear(1000,10))
# print(vgg16_1)
# 修改2:在classifier的最后增加一个线性层并且删除之前的
del vgg16_1.Add_Linear
vgg16_1.classifier.add_module('Add_Linear',nn.Linear(1000,10))
# print(vgg16_1)
del vgg16_1.classifier.Add_Linear
# 修改3:对某一层进行修改
#在不知道的情况下可以先获得其输入的维度
input_feature=vgg16_1.classifier[6].in_features
vgg16_1.classifier[6]=nn.Linear(input_feature,10)
# print(vgg16_1)
# 对网络中的模型进行保存
# 模型保存与读取方法1,模型的结构和参数都会被保存
torch.save(vgg16_1,'vgg16_1.pth')
model1=torch.load('vgg16_1.pth')
# print(model1)
# 保存方式2,只保存模型的参数
torch.save(vgg16_2.state_dict(),'vgg16_2.pth')
model2=torch.load('vgg16_2.pth')
# print(model2)
# 由输出可以看到,只保存了模型的参数,因此加载方式如下
model3=torchvision.models.vgg16(pretrained=False)
model3.load_state_dict(torch.load('vgg16_2.pth'))
# print(model3)
pytorch之模型的保存、加载与修改
最新推荐文章于 2024-04-19 10:33:16 发布