由于经常用,留作存档。
- 整个模型
torch.save(model, path) # 直接保存整个模型
model = torch.load(path) # 直接加载模型
- 模型参数
torch.save(model.state_dict(), path) # 保存模型的参数
model = Model() # 先初始化一个模型
model.load_state_dict(torch.load(path)) # 再加载模型参数
由于经常用,留作存档。
torch.save(model, path) # 直接保存整个模型
model = torch.load(path) # 直接加载模型
torch.save(model.state_dict(), path) # 保存模型的参数
model = Model() # 先初始化一个模型
model.load_state_dict(torch.load(path)) # 再加载模型参数