Pytorch之保存读取模型
发布时间:2018-08-22 14:47,
浏览次数:1038
, 标签:
Pytorch
目录
转自这里
pytorch保存数据
pytorch读取数据
pytorch保存数据
pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式
。而在keras中则是使用.h5文件。
# 保存模型示例代码 print('===> Saving models...') state = { 'state':
model.state_dict(), 'epoch': epoch # 将epoch一并保存 } if not
os.path.isdir('checkpoint'): os.mkdir('checkpoint') torch.save(state,
'./checkpoint/autoencoder.t7')
保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。
pytorch读取数据
pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。
下方的代码和上方的保存代码可以搭配使用。
print('===> Try resume from checkpoint') if os.path.isdir(