1 pytorch 保存模型参数
1 保存模型参数方法: torch.save(model.state_dict(), path)
2 加载
1 定义模型
定义与原模型一致的模型,并进行实例化:
model = Net()
2 加载模型
model.load_state_dict(torch.load(path))
3 在加载模型的基础上继续训练
train(model, train_load, epoch)
记着更换数据集
1 保存模型参数方法: torch.save(model.state_dict(), path)
1 定义模型
定义与原模型一致的模型,并进行实例化:
model = Net()
2 加载模型
model.load_state_dict(torch.load(path))
3 在加载模型的基础上继续训练
train(model, train_load, epoch)
记着更换数据集