b站小土堆pytorch学习记录—— P25-P26 网络模型的使用和修改、保存和读取

一、修改

1.方法

add_module(name: str, module: Module) -> None

name 是要添加的子模块的名称。
module 是要添加的子模块。
调用 add_module 方法会向当前模块中添加一个子模块,并使用指定的名称进行标识。

2.代码

import torchvision
from torch import nn

# 实例化一个未经过预训练的 VGG16 模型
vgg16_false = torchvision.models.vgg16(pretrained=False)

# 实例化一个经过预训练的 VGG16 模型
vgg16_true = torchvision.models.vgg16(pretrained=True)

print("ok")

# 输出经过预训练的 VGG16 模型及修改后的模型
print(vgg16_true)
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
print(vgg16_true)

# 输出未经过预训练的 VGG16 模型及修改后的模型
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

修改前的vgg16_true:

在这里插入图片描述
修改后的vgg16_true:

在这里插入图片描述

修改前的vgg16_true:

在这里插入图片描述

修改后的vgg16_true:

在这里插入图片描述

二、保存和读取

1.方法

保存: torch.save(要保存的模型,“文件路径”)

加载: torch.load(“文件路径”)

2.代码

(1)保存

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式1:模型结构+模型参数
torch.save(vgg16, "vgg16_module1.pth")

# 保存方式2:模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_module2.pth")

(2)加载

import torch
import torchvision

# 方式1 加载模型
module1 = torch.load("vgg16_module1.pth")
print(module1)

#
module2 = torch.load("vgg16_module2.pth")
print(module2)

# 方式2 加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_module2.pth"))
print(vgg16)

运行加载的代码后,打印结果如下

module1:

在这里插入图片描述
module2:

在这里插入图片描述

vgg16:

在这里插入图片描述

可以看到,第二种方式保存的数据,加载后是向量形式,需要通过别的方法加载为模型

3.陷阱

第一种方式加载,在某些条件下可能会报错

例如:

假设自定义一个神经网络,保存:

import torch
import torchvision
from torch import nn

# 陷阱
class Guodong(nn.Module):
    def __init__(self):
        super(Guodong,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self,x):
        x = self.conv1(x)
        return x


guodong = Guodong()
torch.save(guodong,"guodong_method1.pth")

在另一个文件中加载:

import torch

# 陷阱
module = torch.load("guodong_method1.pth")
print(module)

就会报错:

AttributeError: Can’t get attribute ‘Guodong’ on <module ‘main’ from ‘E:\deepLearning\Pycharm\pytroch_project\theFirstFile\module_load.py’>

解决办法:

(1)把Guodong类放在这个文件里

import torch
from torch import nn
import torchvision

class Guodong(nn.Module):
    def __init__(self):
        super(Guodong,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self,x):
        x = self.conv1(x)
        return x

# 陷阱
module = torch.load("guodong_method1.pth")
print(module)

(2)from module_save import *

(module_save)是保存自定义模型的文件

from module_save import *

# 陷阱
module = torch.load("guodong_method1.pth")
print(module)

  • 9
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云霄星乖乖的果冻

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值