保存模型:
# Save a training checkpoint: the next epoch index (epoch + 1, so training can
# resume where it left off), the model weights and the optimizer state.
# NOTE: despite its name, `datadir` is passed to torch.save as the target
# file path, not a directory.
torch.save(
    dict(
        epoch=epoch + 1,
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    ),
    datadir,
)
加载模型:
model = model_class(num_classes=num_classes) # 定义模型
state = torch.load(datadir)
model.load_state_dict(state['state_dict'])