1. 模型保存
使用torch.save(model.state_dict(), path)将model模型保存为.pth文件。
torch.save(model.state_dict(), 'model/my_model.pth')
2. 模型载入
模型实例化后,使用torch.load(path)载入模型。
# 定义模型
model = Net()
# 模型载入
model.load_state_dict(torch.load('model/my_model.pth'))
1. 模型保存
使用torch.save(model.state_dict(), path)将model模型保存为.pth文件。
torch.save(model.state_dict(), 'model/my_model.pth')
2. 模型载入
模型实例化后,使用torch.load(path)载入模型。
# 定义模型
model = Net()
# 模型载入
model.load_state_dict(torch.load('model/my_model.pth'))