PyTorch: Reading and Writing Tensors
Excerpted from: Dive into Deep Learning, PyTorch edition (DEMO)
Link: https://link.zhihu.com/?target=https%3A//github.com/OUCMachineLearning/OUCML/blob/master/BOOK/Dive-into-DL-PyTorch.pdf
Create a Tensor x and store it in a file named x.pt with torch.save.
import torch
from torch import nn
x = torch.ones(3)
torch.save(x, 'x.pt')
Then read the data back from the file into memory with torch.load.
x2 = torch.load('x.pt')
x2
Output:
tensor([1., 1., 1.])
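torch.save and torch.load are not limited to file paths. Both also accept a file-like object. The sketch below (an illustration, not part of the original text) round-trips the same Tensor through an in-memory io.BytesIO buffer and passes map_location="cpu" so the loaded Tensor lands on the CPU regardless of the device it was saved from.

```python
import io

import torch

x = torch.ones(3)

# torch.save accepts any writable file-like object, not only a path.
buffer = io.BytesIO()
torch.save(x, buffer)

# Rewind the buffer before reading it back.
buffer.seek(0)
# map_location="cpu" remaps the saved storages to the CPU on load.
x2 = torch.load(buffer, map_location="cpu")

print(x2)                  # tensor([1., 1., 1.])
print(torch.equal(x, x2))  # True
```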
We can also store a list of Tensors and read it back into memory.
y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list
Output:
[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]
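When several Tensors that share storage are saved together, for example in a list, PyTorch keeps that sharing on load. In this sketch (adapted from the serialization notes in the PyTorch documentation), evens is a view of numbers; after loading, modifying the loaded view in place also changes the loaded numbers.

```python
import io

import torch

numbers = torch.arange(1, 10)
evens = numbers[1::2]  # a view: shares storage with numbers

buffer = io.BytesIO()
torch.save([numbers, evens], buffer)
buffer.seek(0)
loaded_numbers, loaded_evens = torch.load(buffer)

# Doubling the loaded view in place writes through to loaded_numbers,
# because save/load preserved the shared storage.
loaded_evens *= 2
print(loaded_numbers)  # tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])
```

The original numbers tensor is untouched; only the loaded copies share storage with each other.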
We can also store and read back a dictionary that maps strings to Tensors.
torch.save({
    'x': x,
    'y': y
}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy
Output:
{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}
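One detail worth knowing when saving a dictionary of Tensors: if a value is a view of a larger Tensor, torch.save writes the entire underlying storage, not just the visible elements. The sketch below (the sizes and the saved_size helper are illustrative) compares saving a 3-element view of a 1,000,000-element Tensor with saving a .clone() of that view. Cloning first keeps the file small.

```python
import io

import torch

big = torch.zeros(1_000_000)  # about 4 MB of float32 data
head = big[:3]                # a 3-element view into big's storage


def saved_size(obj):
    """Illustrative helper: number of bytes torch.save produces for obj."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return buffer.getbuffer().nbytes


view_size = saved_size({'x': head})           # serializes all of big's storage
clone_size = saved_size({'x': head.clone()})  # serializes only 3 elements

print(view_size > 4_000_000)  # True
print(clone_size < 10_000)    # True
```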
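PyTorch: Reading and Writing Models
state_dict
In PyTorch, the learnable parameters of a Module (its weights and biases) are returned by model.parameters(), and model.state_dict() returns an ordered dictionary that maps each parameter name to its Tensor. Only layers with learnable parameters (and registered buffers, such as batch-norm running statistics) have entries, which is why the ReLU layer in the MLP below contributes none.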
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)
net = MLP()
net.state_dict()
Output:
OrderedDict([('hidden.weight', tensor([[ 0.2448,  0.1856, -0.5678],
                      [ 0.2030, -0.2073, -0.0104]])),
             ('hidden.bias', tensor([-0.3117, -0.4232])),
             ('output.weight', tensor([[-0.4556,  0.4084]])),
             ('output.bias', tensor([-0.3573]))])
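To see which entries a state_dict contains without printing every value, one can map each key to the shape of its Tensor. This sketch repeats the MLP definition so it runs on its own; the parameter-free ReLU layer (act) contributes no keys.

```python
import torch
from torch import nn


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        return self.output(self.act(self.hidden(x)))


net = MLP()
# Map each state_dict key to the shape of its Tensor.
shapes = {name: tuple(t.shape) for name, t in net.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)
# hidden.weight (2, 3)
# hidden.bias (2,)
# output.weight (1, 2)
# output.bias (1,)
```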
An optimizer also has a state_dict, which holds its hyperparameters and its internal state.
optimizer = torch.optim.SGD(net.parameters(), lr=0.001,
                            momentum=0.9)
optimizer.state_dict()
Output:
{'param_groups': [{'dampening': 0,
   'lr': 0.001,
   'momentum': 0.9,
   'nesterov': False,
   'params': [4736167728, 4736166648, 4736167368, 4736165352],
   'weight_decay': 0}],
 'state': {}}
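The 'state' entry above is empty because the optimizer has not taken a step yet. A minimal sketch, using a single nn.Linear(3, 1) as a stand-in model and a random batch, shows that SGD with momentum stores a momentum_buffer per parameter after its first step. The exact contents vary with the PyTorch version; recent releases, for example, identify parameters by integer indices rather than by the object ids seen in the output above.

```python
import torch
from torch import nn

model = nn.Linear(3, 1)  # stand-in model: one weight and one bias
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print(optimizer.state_dict()["state"])  # {} before any step

# One training step on a random batch populates the optimizer state.
loss = model(torch.randn(4, 3)).sum()
loss.backward()
optimizer.step()

state = optimizer.state_dict()["state"]
print(len(state))  # 2: one entry per parameter
for param_state in state.values():
    print(list(param_state))  # ['momentum_buffer']
```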
Saving and Loading Models
The recommended approach is to save and load only the model's state_dict.
Save:
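torch.save(model.state_dict(), PATH)  # the recommended file extension is .pt or .pth
Load:
model = TheModelClass(*args, **kwargs)  # placeholders: your model class and its constructor arguments
model.load_state_dict(torch.load(PATH))
model.eval()  # switch dropout and batch-norm layers to evaluation mode before inference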
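To resume training rather than only run inference, a common pattern described in the official tutorial is to save one dictionary that holds both the model's and the optimizer's state_dict plus bookkeeping such as the epoch. In this sketch the key names, the epoch value, and the stand-in nn.Linear model are illustrative choices, not something PyTorch requires.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(3, 1)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
epoch = 5  # hypothetical training progress to record

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pt")
    # Bundle everything needed to resume into a single dictionary.
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)

    # Rebuild the objects first, then restore their saved state.
    model2 = nn.Linear(3, 1)
    optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.001, momentum=0.9)
    checkpoint = torch.load(path)
    model2.load_state_dict(checkpoint["model_state_dict"])
    optimizer2.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1

print(start_epoch)  # 6
```

After restoring, call model2.train() to continue training or model2.eval() for inference.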
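Experiment: save the parameters of net to a file, then load them into a new MLP instance net2. Because both networks now hold the same parameters, they produce the same output for the same input X.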
X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y
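Output (recorded with an older PyTorch release; since PyTorch 1.2, comparison operators return torch.bool, so newer versions display True instead of 1):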
tensor([[1],
[1]], dtype=torch.uint8)
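Y2 == Y compares element by element and returns a Tensor. When a single yes/no answer is wanted, torch.equal returns True only if both Tensors have the same shape and identical elements. This sketch repeats the experiment with a hypothetical make_net helper that builds an nn.Sequential with the same layers as MLP, and with an in-memory buffer in place of ./net.pt.

```python
import io

import torch
from torch import nn


def make_net():
    """Hypothetical helper: same layers as MLP, Linear(3, 2) -> ReLU -> Linear(2, 1)."""
    return nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))


net = make_net()
X = torch.randn(2, 3)

buffer = io.BytesIO()  # in-memory substitute for ./net.pt
torch.save(net.state_dict(), buffer)
buffer.seek(0)

net2 = make_net()  # freshly initialized with different parameters
net2.load_state_dict(torch.load(buffer))

with torch.no_grad():
    Y, Y2 = net(X), net2(X)

print(torch.equal(Y, Y2))  # True
```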
Official documentation: https://pytorch.org/tutorials/beginner/saving_loading_models.html