pytorc保存及加载模型参数方法
pytorch保存模型、参数的方法
方法一:只保存模型的参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu' )
filepath = 'model.dat'
# 保存参数
torch.save(model.state_dict(), filepath)
# 加载模型参数 , map_location: 把数据加载到哪个device(GPU或CPU)
model.load_stat
原创
2021-11-11 17:45:20 ·
1556 阅读 ·
0 评论