(pytorch-深度学习系列)读取和存储数据-学习笔记

读取和存储数据

我们可以使用pt文件存储Tensor数据:

import torch
from torch import nn

x = torch.ones(3)
torch.save(x, 'x.pt')

这样我们就将数据存储在名为x.pt的文件中了
我们可以从文件中将该数据读入内存:

x2 = torch.load('x.pt')
print(x2)

还可以存储Tensor列表到文件中,并读取:

y = torch.zeros(4)
torch.save([x, y], "xy.pt")
xy_list = torch.load("xy.pt")
print(xy_list)

不仅如此,还可以存储一个键值为Tensor变量的字典:

torch.save({'x':x, 'y':y}, "xy_dict")
xy_dict = torch.load("xy_dict")
print(xy_dict)

对模型参数进行读写:

对于Module类的对象,我们可以使用model.parameters()函数来访问模型的参数。而state_dict函数将会返回一个模型的参数名称到参数Tensor对象的一个字典对象。

class my_module(mm.Module):
	def __init__(self):
		super(my_module, self)
		self.hidden = nn.Linear(3, 2)
		self.action = nn.ReLU()
		self.output = nn.Linear(2, 1)

	def forward(self, x):
		middle = self.action(self.hidden(x))
		return self.output(middle)	

net = my_module()
net.state_dict()

输出:

OrderedDict([('hidden.weight', tensor([[ 0.2448,  0.1856, -0.5678],
                      [ 0.2030, -0.2073, -0.0104]])),
             ('hidden.bias', tensor([-0.3117, -0.4232])),
             ('output.weight', tensor([[-0.4556,  0.4084]])),
             ('output.bias', tensor([-0.3573]))])

但是,只有具有可变参数(可学习参数)的网络层才会在state_dict中,

同样的,优化器(optim)也有一个state_dict,这个函数返回一个字典,该字典包含优化器的状态以及其超参数信息:

optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()

输出:

{'param_groups': [{'dampening': 0,
   'lr': 0.001,
   'momentum': 0.9,
   'nesterov': False,
   'params': [4736167728, 4736166648, 4736167368, 4736165352],
   'weight_decay': 0}],
 'state': {}}

那么就可以通过保存模型的state_dict来保存模型

torch.save(net.state_dict(), PATH)

model = my_module(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

还可以直接保存整个模型:

torch.save(model, PATH)
model = torch.load(PATH)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值