# save
torch.save(model.state_dict(), PATH) //模型参数保留在PATH下
# load
model = MyModel(*args, **kwargs) //加载模型
model.load_state_dict(torch.load(PATH)) //加载模型参数
model.eval() //转换成验证、测试模式
写PATH
checkpoint_path = os.path.join('./checkpoint',datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss'))
if not os.path.exists(checkpoint_path):
os.makedirs(checkpoint_path)
//每10轮保存一个模型
if not epoch % 10:
torch.save(net.state_dict(), os.path.join(checkpoint_path,str(epoch)+'.pth'))