模型参数
1.仅保存学习到的参数,用以下命令
torch.save(model.state_dict(), PATH)
2.加载model.state_dict
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
整个模型状态
1.保存整个模型的状态:
torch.save(model,PATH)
2.加载整个模型状态:
model = torch.load(PATH)
model.eval()