PyTorch 保存和加载模型、查看模型结构的方法(入门级,不包括保存优化器、只加载部分参数等进阶方法)

本文详细介绍了PyTorch中模型的保存和加载方法,包括完整网络和仅参数的保存。官方推荐只保存网络参数,因体积小且加载速度快。后缀.pth、.pt、.pkl等均可用于模型保存,大小无差异。同时,文章展示了如何打印模型结构和参数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、如何保存和加载模型

1.1 保存和加载模型的两种方法

保存模型有两种最基本的方式:

(1)保存整个网络,包括模型结构,参数和项目文件结构

保存整个网络:

torch.save(net, path1) 

加载网络:

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = torch.load(path1, map_location = device)   
# map_location 参数可选,用于在多卡GPU中训练和推理使用的 CUDA 序号不一样的情况

(2)只保存网络参数

保存网络参数:

torch.save(net.state_dict(),path2)

加载网络参数:

# 模型定义
class MyModel(nn.Module):
	def __init__(self, para): 
		...
	def forward(self, x):
		...

# 需要先加载模型结构
model = MyModel(para = para)

# 再加载网络参数
model.load_state_dict(torch.load(path2, map_location = device))   # map_location 参数可选

方法二(只保存模型参数)是官方推荐的方法,运行速度快,且占空间较小。需要注意的是 net.state_dict() 是将网络参数保存为字典形式(OrderedDict),load_state_dict() 加载的并不是网络参数的pth文件,而是字典。

1.2 代码示例

代码示例,一看就会:

import torchvision
import torch

# 加载torchvision中自带的vgg16网络
vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存模型
torch.save(vgg16, "vgg16.pth")

# 加载模型
model=torch.load("vgg16.pth")

# 打印模型
print(model)


# 保存模型参数
torch.save(vgg16.state_dict(),"vgg16_para.pth")

##################################################
# 不建议这样写,虽然可以运行,但会报错
# 加载模型参数
vgg16_para=torch.load("vgg16_para.pth")

# 加载字典
vgg16.load_state_dict(torch.load(vgg16_para))
##################################################

# 应该这样写,不会有问题
# 加载模型,前提是vgg16的模型结构已经定义好了
vgg16.load_state_dict(torch.load("vgg16_para.pth"))

model=vgg16
print(model)

# 打印模型参数
print(vgg16_para)

1.3 注意:建议只保存模型参数

强烈建议只保存模型参数,而非保存整个网络。PyTorch 官方也是这么建议的。

PyTorch 如果只保存模型权重。那么只是储存一个普通的字典。如果保存整个网络(模型结构 + 参数,PyTorch 是以 pickle 序列化格式格式保存的。其中除了模型,还保存了生成模型的项目文件名称和路径等信息,加载模型时候 pickle 反序列化是必须要路径合代码完全一致的。

也就是说如果保存整个网络,将生成的 .pt 模型文件移动到其他项目中是用不了的,会报错 No module named ‘xxx’ ,除非新项目与原项目的文件完全相同。

2、后缀问题

保存模型的后缀有 .pth、.pt、.pkl、.ckpt 等多种格式,这些后缀都可以使用且没有什么区别,保存的模型大小也一样。以下写法没有区别:

torch.save(vgg16, "vgg16.pt")
torch.save(vgg16,"vgg16.ckpt")
torch.save(vgg16,"vgg16.pth")
torch.save(vgg16,"vgg16.pkl")

且不同后缀保存的文件大小也完全相等:

在这里插入图片描述
这样看的话只保存参数只比保存整个网络模型小 7 KB,似乎也不差这点存储空间。

3、模型和参数是可以打印的

通过打印模型就可以清晰地看到模型的结构。

打印模型:print(model)
打印模型参数:print(vgg16_para)

有时候无法用 print(model) 查看模型结构,也可以用 model.children() 方法逐层打印模型:

for i in model.children():
    print(i)

两种方法得到的模型结构是一样的。

网络模型:

# print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

参数:

可以看到是以 OrderedDict 格式保存的,数据有很多,图片放不下。
在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ctrl A_ctrl C_ctrl V

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值