日萌社
人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新)
保存没有压缩的原始模型和及其模型状态、保存压缩后的模型和及其模型状态、加载没有压缩的原始模型文件和及其模型状态、加载压缩后的模型和及其模型状态
Pytorch:模型保存与加载方式
1.保存模型权重
torch.save(model.state_dict(), "./model_save/xx.pt")
2.加载模型权重
model = 模型类Model()
model.load_state_dict(torch.load("./model_save/xx.pt"))
3.例子
hidden = torch.zeros(num_layers, 1, hidden_size).to(device)
rnn = RNN(n_letters, n_hidden, n_categories).to(device)
rnn.load_state_dict(torch.load("./model_save/rnn_embedding.pt"))