1. 仅保存权重信息
# 模型路径
path = "state_dict_model.pt"
# 保存
torch.save(model.state_dict(), path)
# 加载
model = Network()
# 将训练好的权重加载到模型中
model.load_state_dict(torch.load(path))
2. 保存全部信息
# 对整个模型进保存和加载
path = "entire_model.pt"
# 保存模型
torch.save(model, path)
# 加载模型
model = torch.load(path)
3. 保存checkpoint
# 保存checkpoint
path = 'model.pt'
torch.save(
{
'epoch':epochs,
'model_state_dict': model.state_dict(),
'optimizer_state_dict':optimizer.state_dict(),
'loss': loss_fn
},path
)
# 加载
model = Network(input_num)
optimizer = optimizer = torch.optim.SGD(model.parameters(), lr=lr)
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
4. 其他测试
当我们打印模型时:
net = MyModel(3)
print(net.state_dict().items()) # 输出模型每一层的权重
print(net.state_dict()) # 输出模型每一层的权重