保存模型
第一种:保存整个模型,按道理应该不用自己再导入网络框架,但我在实际应用过程中如果没有导入整个网络框架是会报错的。
torch.save(net,dir)
#其中net是你代码中自定义的网络入口(下同)
#dir是你将要保存的路径加自定义名字(下同)
#eg:"D:/pytroch/mode.pkl"
第二种:保存模型参数,网络框架还需自己构建。
torch.save(net.state_dict(),dir)
加载模型
第一:
net = torch.load(dir)
注意:它并不是独立于原先代码就可以运行,需要获取原来网络的框架,并且数据格式必须一致。通过一下例子进一步讲解。
from ... import Net
#在另一段代码运行的时候需要加载原来训练的网络框架,不然会报错
image.unsqueeze(0)
#训练的时候一般是(batch_sieze,x,y,z),但推理的时候只有(x,y,z)
net.to(device)
image.to(device)
#你训练是用CPU还是GPU,对应声明即可
第二种:
一般来说大家都是推荐这一种,pytorch官方给出的很多预训练模型也都是用这种方法加载的。
net = Net() #我们必须实例化原来函数的网络架构
temp = torch.load(dir)
net.load_state_dict(temp)
这样net就是一个即有网络框架又有模型参数的网络了。
值得一提的是,无论哪一种,其实都是可以进行网络裁剪的。
nn.Sequential(*list(net.children())[n:m])
#两种都适用,更具体可以看我Pytorch迁移学习那一章。