Pytorch读取存储

Pytorch读写Tensor

摘自:动手学深度学习PYTORCH版(DEMO)
链接:https://link.zhihu.com/?target=https%3A//github.com/OUCMachineLearning/OUCML/blob/master/BOOK/Dive-into-DL-PyTorch.pdf在这里插入图片描述

import torch
from torch import nn
x = torch.ones(3)
torch.save(x, 'x.pt')

然后将数据从存储的文件读回内存。

x2 = torch.load('x.pt')
x2

输出:

tensor([1., 1., 1.])

还可以存储一个Tensor列表并读回内存

y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list

输出:

[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]

存储并读取一个从字符串映射到Tensor的字典

torch.save({
'x': x, 
'y': y
}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy

输出:

{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}

Pytorch读写模型

state_dict

在这里插入图片描述

class MLP(nn.Module):
	def __init__(self):
		super(MLP, self).__init__()
		self.hidden = nn.Linear(3, 2)
		self.act = nn.ReLU()
		self.output = nn.Linear(2, 1)
		
	def forward(self, x):
		a = self.act(self.hidden(x))
		return self.output(a)
net = MLP()
net.state_dict()

输出:

OrderedDict([('hidden.weight', tensor([[ 0.2448, 0.1856, -0.5678],
[ 0.2030, -0.2073, -0.0104]])),
('hidden.bias', tensor([-0.3117, -0.4232])),
('output.weight', tensor([[-0.4556, 0.4084]])),
('output.bias', tensor([-0.3573]))])

在这里插入图片描述

optimizer = torch.optim.SGD(net.parameters(), lr=0.001,
momentum=0.9)
optimizer.state_dict()

输出:

{'param_groups': [{'dampening': 0,
'lr': 0.001,
'momentum': 0.9,
'nesterov': False,
'params': [4736167728, 4736166648, 4736167368, 4736165352],
'weight_decay': 0}],
'state': {}}
保存和加载模型

在这里插入图片描述
保存

torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth

加载

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

实验:

X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y

输出:

tensor([[1],
[1]], dtype=torch.uint8)

在这里插入图片描述
官方文档https://pytorch.org/tutorials/beginner/saving_loading_models.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值