1 保存模型(整个模型或仅参数)
# Persist only the learned parameters (the model's state_dict).
torch.save(model.state_dict(), 'model/model_weights.pth')
# Persist the entire model object: structure and parameters together.
torch.save(model, 'model/model.pth')
2 重新加载
# Method 1: load the entire pickled model (structure and parameters).
# map_location remaps the stored tensors onto `device`, so a file that was
# saved on a GPU can still be loaded on a CPU-only machine.
# NOTE(review): newer PyTorch releases default torch.load to
# weights_only=True, which rejects fully pickled models; on those versions
# this call also needs weights_only=False -- confirm against the installed
# version before adding it (older releases do not accept the keyword).
new_model2 = torch.load('model/model.pth', map_location=device)
# Method 2: rebuild the network, then load only its parameters.
new_model = Net()
new_model = new_model.to(device)
# map_location places the loaded tensors on `device`, so weights saved on a
# GPU can also be restored on a CPU-only machine.
new_model.load_state_dict(
    torch.load('model/model_weights.pth', map_location=device))
3 调用测试函数预测
# Method 1: evaluate the fully pickled model.
# map_location lets a model saved on a GPU be loaded on a CPU-only machine.
new_model2 = torch.load('model/model.pth', map_location=device)
new_model2.eval()  # switch to evaluation mode before testing
new_model2 = new_model2.to(device)
# The explanatory text after the call is now a real comment; without the
# leading '#' the line was a syntax error.
print(test(test_dl, new_model2))  # prints the test loss and accuracy
# Method 2: rebuild the network, load its parameters, then evaluate.
new_model = Net()
new_model = new_model.to(device)
# map_location lets weights saved on a GPU be loaded on a CPU-only machine.
new_model.load_state_dict(
    torch.load('model/model_weights.pth', map_location=device))
new_model.eval()  # switch to evaluation mode before testing
# The explanatory text after the call is now a real comment; without the
# leading '#' the line was a syntax error.
print(test(test_dl, new_model))  # prints the test loss and accuracy
4 保存最优参数
import copy
# Train for `epochs` epochs and remember the weights that achieved the
# highest test accuracy.
# deepcopy takes an independent snapshot, so later training steps do not
# overwrite the saved best weights.
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

# Per-epoch history of the training and test metrics.
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)
    epoch_test_loss, epoch_test_acc = test(test_dl, model)

    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)

    template = ("epoch:{:2d}, train_loss: {:.5f}, train_acc: {:.1f}% ,"
                "test_loss: {:.5f}, test_acc: {:.1f}%")
    print(template.format(
        epoch, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))

    # Snapshot the weights whenever the test accuracy improves.
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model_wts = copy.deepcopy(model.state_dict())
print("Done!")

# Put the best weights back into the model, save it, and switch it to
# evaluation mode.
model.load_state_dict(best_model_wts)
torch.save(model, 'model/best_model.pth')
model.eval()
5 保存和恢复检查点checkpoint
保存和加载通用检查点(checkpoint)用于推理或恢复训练,有助于从上次中断的位置继续。保存通用检查点时,仅保存模型的 state_dict 是不够的;优化器的 state_dict 同样重要,因为其中包含随模型训练而更新的缓冲区和参数。此外,还可以保存中断时所在的轮次(epoch)、最近一次记录的训练损失、外部层(例如 torch.nn.Embedding),以及其他取决于您自己算法的内容。
# Train and write one checkpoint file per epoch.
# PATH is a template; '{}' is filled with the epoch index when saving.
PATH = "model/model_checkpoint_{}.pt"

for epoch in range(epochs):
    epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)
    epoch_test_loss, epoch_test_acc = test(test_dl, model)

    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)

    template = ("epoch:{:2d}, train_loss: {:.5f}, train_acc: {:.1f}% ,"
                "test_loss: {:.5f}, test_acc: {:.1f}%")
    print(template.format(
        epoch, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))

    # Snapshot the weights whenever the test accuracy improves.
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model_wts = copy.deepcopy(model.state_dict())

    # A checkpoint is written every epoch; to save only every N epochs,
    # guard this call with `if (epoch + 1) % N == 0:`.
    torch.save({
        # Number of epochs completed so far. The original stored the loop
        # bound `epochs`, which made every checkpoint report the same value.
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, PATH.format(epoch))
print("Done!")
# Restore a checkpoint.
# First re-create the model and the optimizer, then load the saved state
# into both.
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# PATH is a template, so it must be formatted with the epoch index of the
# checkpoint to load (here: the file written by the last epoch of the loop).
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
# machine.
checkpoint = torch.load(PATH.format(epochs - 1), map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

print(epoch)
print(optimizer)
# eval() switches the model to evaluation mode and returns the model, so
# this prints the network structure.
print(model.eval())
输出如下:
100
Adam (
Parameter Group 0
amsgrad: False
betas: (0.9, 0.999)
capturable: False
eps: 1e-08
foreach: None
lr: 0.0005
maximize: False
weight_decay: 0
)
Net(
(conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
(conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(fc1): Linear(in_features=57600, out_features=1024, bias=True)
(fc2): Linear(in_features=1024, out_features=2, bias=True)
)