已有网络模型的加载与保存
对于我们要保存的网络模型,有两种保存方法
- 保存网络模型的结构与参数
- 仅保留网络模型的参数
import torch
import torchvision
from torch import nn
# 训练好的vgg网络
vgg_train = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
# vgg_test = torchvision.models.vgg16(pretrained=True)
# 未训练的vgg网络,只有网络结构
vgg_not_train = torchvision.models.vgg16(weights=None)
# vgg_not_train = torchvision.models.vgg16(pretrained=True)
# 保存神经网络,保存所有信息(结构+参数)
torch.save(vgg_train, 'vgg16_method1.pth')
# 保存神经网络参数
torch.save(vgg_train.state_dict(), 'vgg16_method2.pth')
print(vgg_train.state_dict())
当加载两种保存的网络时也会有些许差异
import torch
import torchvision
# 加载模型+参数
vgg16_first = torch.load("vgg16_method1.pth")
# 仅加载参数
vgg16_second = torchvision.models.vgg16(weights='None')
# 将参数加载到我们的空白模型中
vgg16_second.load_state_dict(torch.load("vgg16_method2.pth"))
已有网络模型的添加与修改
我们导入的vgg16
的网络模型结构为
我们如果想添加一层add_Linear
作用是将分类为1000的网络转化为10分类
# 已训练好的网络的添加
vgg_train.classifier.add_module('add_Linear',nn.Linear(1000,10))
如果想要修改
# 已训练好的网络的修改
vgg_not_train.classifier[6] = nn.Linear(4096, 10)
print(vgg_not_train)
自定义网络模型的加载
首先我们先自定义网络结构,并保存为.pth文件
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 3)
def forward(self, x):
x = self.conv1(x)
return x
net = Net()
# 保存自定义网络
torch.save(net, 'user_define_net_save.pth')
这里需要注意,我们在新的文件加载这个网络模型时,不能直接通过
model = torch.load('user_define_net_save.pth')
进行加载,而是要先引入我们网络模型的类
import torch
from torch import nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 3)
def forward(self, x):
x = self.conv1(x)
return x
model = Net()
# 注意无法直接通过这条语句导入,需要先引入网络定义
model = torch.load('user_define_net_save.pth')
或者通过
from user_define_net_save import *
来导入我们的类信息
import torch
from torch import nn
from user_define_net_save import *
model = Net()
# 注意无法直接通过这条语句导入,需要先引入网络定义
model = torch.load('user_define_net_save.pth')