有两套模板,配套使用:
模板一:直接加载参数模型
# 保存权重
torch.save(model, save_dirt)
# 载入模型
model = torch.load(save_dirt)
model.to(device)
# 替换代码
model = torch.load(save_path, map_location = device)
模板二:分别加载网络的结构和参数
# 保存
torch.save(model.state_dict(), save_path)
# 加载参数给模型
# 1、先实例化模型
model = Model()
# 2、根据权重文件将参数赋值该创建好的模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.load('./weight/best.pkl', map_location = device)
model.load_state_dict(checkpoint)# 尝试将权重文件载入实例化为模型中的model