【Pytorch】12.网络模型的加载、修改与保存

已有网络模型的加载与保存

对于我们要保存的网络模型,有两种保存方法

  • 保存网络模型的结构与参数
  • 仅保留网络模型的参数
import torch
import torchvision
from torch import nn

# 训练好的vgg网络
vgg_train = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
# vgg_test = torchvision.models.vgg16(pretrained=True)


# 未训练的vgg网络,只有网络结构
vgg_not_train = torchvision.models.vgg16(weights=None)
# vgg_not_train = torchvision.models.vgg16(pretrained=True)

# 保存神经网络,保存所有信息(结构+参数)
torch.save(vgg_train, 'vgg16_method1.pth')
# 保存神经网络参数
torch.save(vgg_train.state_dict(), 'vgg16_method2.pth')

print(vgg_train.state_dict())

当加载两种保存的网络时也会有些许差异

import torch
import torchvision

# 加载模型+参数
vgg16_first = torch.load("vgg16_method1.pth")

# 仅加载参数
vgg16_second = torchvision.models.vgg16(weights='None')
# 将参数加载到我们的空白模型中
vgg16_second.load_state_dict(torch.load("vgg16_method2.pth"))

已有网络模型的添加与修改

我们导入的vgg16的网络模型结构为
在这里插入图片描述
我们如果想添加一层add_Linear作用是将分类为1000的网络转化为10分类

# 已训练好的网络的添加
vgg_train.classifier.add_module('add_Linear',nn.Linear(1000,10))

在这里插入图片描述

如果想要修改

# 已训练好的网络的修改
vgg_not_train.classifier[6] = nn.Linear(4096, 10)
print(vgg_not_train)

在这里插入图片描述

自定义网络模型的加载

首先我们先自定义网络结构,并保存为.pth文件

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)

    def forward(self, x):
        x = self.conv1(x)
        return x
net = Net()
# 保存自定义网络
torch.save(net, 'user_define_net_save.pth')

这里需要注意,我们在新的文件加载这个网络模型时,不能直接通过

model = torch.load('user_define_net_save.pth')

进行加载,而是要先引入我们网络模型的类

import torch
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)

    def forward(self, x):
        x = self.conv1(x)
        return x


model = Net()
# 注意无法直接通过这条语句导入,需要先引入网络定义
model = torch.load('user_define_net_save.pth')

或者通过

from user_define_net_save import *

来导入我们的类信息

import torch
from torch import nn
from user_define_net_save import *


model = Net()
# 注意无法直接通过这条语句导入,需要先引入网络定义
model = torch.load('user_define_net_save.pth')
  • 6
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值