torch文件保存与加载——【torch学习笔记】

模型文件保存与加载

引用翻译:《动手学深度学习》

到目前为止,我们讨论了如何处理数据,如何建立、训练和测试深度学习模型。然而,在某些时候,我们很可能对我们获得的结果感到满意,我们希望保存结果以便以后使用和分发。同样,当运行一个漫长的训练过程时,保存中间结果(检查点)是最好的做法,以确保我们不会在被服务器的电源线绊倒时失去几天的计算量。同时,我们可能想加载一个预训练的模型(例如,我们可能有英语的词嵌入,并将其用于我们花哨的垃圾邮件分类器)。对于所有这些情况,我们都需要加载和存储单个权重向量和整个模型。本节讨论这两个问题。

一、张量

在其最简单的形式中,我们可以直接使用保存和加载函数来分别存储和读取张量。这就像预期的那样工作。

import torch
from torch import nn
x=torch.arange(4,dtype = torch.float32)
torch.save(x,"x-file")

然后,我们把存储文件中的数据读回内存。

x2 = torch.load("x-file")
x2

输出:

tensor([0., 1., 2., 3.])

我们也可以存储一个张量列表并将其读回内存。

y = torch.zeros(4)
torch.save([x, y],'x-files')  # 保存
x2, y2 = torch.load('x-files')  # 读取
(x2, y2)

输出:

(tensor([0., 1., 2., 3.]), tensor([0., 0., 0., 0.]))

我们甚至可以写和读一个字典,从一个字符串映射到一个张量。这很方便,例如,当我们想读取或写入一个模型中的所有权重时。

mydict = {'x': x, 'y': y}
torch.save( mydict,'mydict.npy') # 存储字典型
mydict2 = torch.load('mydict.npy')
mydict2

输出:

{'x': tensor([0., 1., 2., 3.]), 'y': tensor([0., 0., 0., 0.])}

二、torch模型参数

保存单个权重向量(或其他张量)是很有用的,但如果我们想保存(以后再加载)整个模型的话,就会变得非常繁琐了。

毕竟,我们可能有数以百计的参数组散布在各个地方。编写一个脚本来收集所有的术语并将它们与一个架构相匹配是相当费力的。由于这个原因,Torch提供了内置的功能来加载和保存整个网络,而不仅仅是单个权重向量。

需要注意的一个重要细节是,这将保存模型参数而不是整个模型。也就是说,如果我们有一个3层的MLP,我们需要单独指定架构。

这样做的原因是,模型本身可以包含任意的代码,因此它们不能很容易地被序列化(对于编译过的模型,有一种方法可以做到这一点–关于它的技术细节,请参考Torch的文档)。

其结果是,为了恢复一个模型,我们需要在代码中生成架构,然后从磁盘加载参数。延迟初始化在这里是相当有利的,因为我们可以简单地定义一个模型,而不需要把实际的值放在那里。让我们从我们最喜欢的MLP开始。

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)
        self.relu = nn.ReLU()
    def forward(self, x):
        H_1 = self.relu(self.hidden(x))
        out = self.output(H_1)
        return out

net = MLP()
x = torch.randn(size=(2, 20))
y = net(x)

接下来,我们将模型的参数存储为一个文件,名称为mlp.params 。

torch.save(net.state_dict(), 'mlp.params')

为了检查我们是否能够恢复模型,我们实例化了一个原始MLP模型的克隆。与模型参数的随机初始化不同,这里我们直接读取存储在文件中的参数。

clone = MLP()
clone.load_state_dict(torch.load("mlp.params"))
clone.eval()

输出:

MLP(
  (hidden): Linear(in_features=20, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
  (relu): ReLU()
)

由于两个实例都有相同的模型参数,相同的输入x的计算结果应该是相同的。让我们来验证这一点。

yclone = clone(x)
yclone == y

输出:

tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

三、摘要

  • save和load函数可以用来对张量对象执行文件I/O。

  • load_parameters和save_parameters函数允许我们在Torch中保存网络的整个参数集。

  • 保存架构必须在代码中而不是在参数中进行。

四、练习

1、即使不需要将训练好的模型部署到不同的设备上,存储模型参数的实际好处是什么?

2、假设我们只想重复使用一个网络的部分内容,将其纳入不同结构的网络中。你如何在一个新的网络中使用,比如说以前的网络中的前两层。

3、你将如何保存网络结构和参数?你会对结构施加什么限制?

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值