pytorch模型的保存与加载注意事项:

保存和加载模型:
torch.save(net,'./model.pth')   #保存整个模型及其参数
net=torch.load('./model.pth')  #加载整个模型及其参数
#或者
torch.save(net.state_dict(),'./model-dict.pth')#仅仅保存模型参数
net.load_state_dict(torch.load('./model-dict.pth')) #仅仅加载模型参数(所以需要事先定义一个模型net)

net.load_state_dict()和torch.load()的不同,前者需要你先定义一个模型,然后再load_state_dict()。
torch.load()直接加载整个模型,会把模型和模型参数一起load进来。完成了模型的定义和加载参数的两个过程。

需要注意的是,在保存模型之前,需要把模型进行eval, 即把模型从训练阶段转化为测试阶段,固定当下的模型参数,用于接下来的模型预测。如果不指定模型eval模式,那么加载回来的模型并不是和原先保存的模型相同。

简单说,原先的net在保存之前,要eval一下,load之后的net也要eval一下,把所有参数freeze掉。才保证两个net完全相同(输入相同tensor得到完全一致的结果)。

因为在模型的训练阶段,在进行有BN层或者有Dropout层的模型训练中,获取的批次数据属性(均值、方差)会被记录下来,用于对测试数据的标准化。或者对于Dropout层,在训练的阶段会有一些神经元权重被置零,但是在测试阶段,这些神经元又被重新使用。如果不进行model.eval()的话,那么每次测试阶段这些参数的值会在前向传播的时候发生改变。导致模型不稳定。

使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大。

model.train()
启用 BatchNormalization 和 Dropout
model.eval()
不启用 BatchNormalization 和 Dropout

总结一下:训练完train样本后,生成的模型model要用来测试样本。在model test之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有batch normalization层所带来的的性质。当训练集和测试集的样本分布是不一样的,尤其需要注意这一点。

评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值