# 来源:《动手学深度学习》
# 模型
class MLP(nn.Module):
    """Multilayer perceptron with one ReLU-activated hidden layer.

    Args:
        in_features: Size of the last dimension of the input (default 20).
        hidden_features: Width of the hidden layer (default 256).
        out_features: Size of the last dimension of the output (default 10).

    Calling ``MLP()`` with no arguments builds the 20 -> 256 -> 10 network.
    """

    def __init__(self, in_features: int = 20, hidden_features: int = 256,
                 out_features: int = 10):
        super().__init__()
        # The attribute names are the prefixes of the state_dict keys
        # ("hidden.*", "output.*"); renaming them would stop previously
        # saved parameter files from loading.
        self.hidden = nn.Linear(in_features, hidden_features)
        self.output = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Return ``output(relu(hidden(x)))`` for the input tensor ``x``."""
        return self.output(F.relu(self.hidden(x)))
# Instantiate the network, then push one random batch of shape (2, 20)
# through it; Y is kept as the reference output for later comparison.
net = MLP()
X = torch.randn(2, 20)
Y = net(X)
# 保存模型参数
# Write only the parameters (the state_dict) to disk, not the module object.
torch.save(obj=net.state_dict(), f='mlp.params')
# 加载模型参数
# The file holds parameters only, so a fresh MLP must be constructed first
# and the saved values copied into it.
# NOTE(review): torch.load unpickles the file; only load files you trust.
clone = MLP()
clone.load_state_dict(state_dict=torch.load(f='mlp.params'))
# Put the reloaded module into evaluation mode.
clone.eval()
# 验证:加载后的模型输出应与原模型一致
# Feed the same input to the reloaded model; if the parameters were restored
# correctly, every element of the comparison is True.
Y_clone = clone(X)
# A bare `Y_clone == Y` is only echoed by a REPL/notebook and is silently
# discarded in a script, so print the comparison explicitly.
print(Y_clone == Y)