从头学pytorch(十二):模型保存和加载

本文介绍了PyTorch中模型的保存和加载,包括Tensor的读写,state_dict的使用,以及两种模型保存和加载的方法。重点讲解了如何使用torch.save()和Module.load_state_dict()进行模型参数的保存和加载,确保模型在不同设备间的迁移和复用。
摘要由CSDN通过智能技术生成

模型读取和存储

总结下来,就是几个函数

  1. torch.load()/torch.save()

通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换.

  1. Module.state_dict()/Module.load_state_dict()

state_dict()获取模型参数.load_state_dict()加载模型参数

读写Tensor

我们可以直接使用save函数和load函数分别存储和读取Tensorsave使用Python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,使用save可以保存各种对象,包括模型、张量和字典等。而laod使用pickle unpickle工具将pickle的对象文件反序列化为内存。
下面的例子创建了Tensor变量x,并将其存在文件名同为x.pt的文件里。

import torch
from torch import nn

x = torch.ones(3)
torch.save(x, 'x.pt')

然后我们将数据从存储的文件读回内存。

x2 = torch.load('x.pt')
x2

输出:

tensor([1., 1., 1.])

我们还可以存储一个

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值