pytorch checkpoint_记录一次pytorch加载模型的采坑

22b70cff595de872de9780b5da6ed5ca.png

因为我有多块gpu,想都用来做训练,查阅资料,得用到dataparallel这个方法,刚好从github上找到了一个比较完整的pytorch图像分类代码也是使用的dataparallel方法,就直接拿来试验。当训练完毕储存模型,想再次加载测试一下结果的时候,evaluate的结果基本都是1%的情况,这让我很费解,训练完毕的时候可是有73%呢,又是就开始了长达整个半天的采坑之旅。

原始加载模型代码如下:

# Load checkpoint.
 

1. 首先我注意到model.train()和model.eval(),主要是针对model 在训练时和评价时不同的 Batch NormalizationDropout 方法模式。

model.eval()时,pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大。而储存模型是在test()函数之后,这已经将模型切换至eval模式,而调用的时候test时又将模型切换成eval模式,切换了两次,这可能会导致错误,因此直接将储存模型写到test()函数之前,但并没有效果。

for 

2. 后来在网上查询,得知对于dataparallel模式,model.module才是权值真正所在的地方,因此我就在储存的时候改动代码,储存model.module而不是model:

save_checkpoint

然后加载模型的时候,我注意到源代码是override模式,我们直接使用简单粗暴的方法,对模型全部load,改动代码如下:

args.checkpoint = os.path.dirname(args.resume)
checkpoint = torch.load(args.resume)
best_acc = checkpoint['best_acc']
start_epoch = checkpoint['epoch']


# 注释掉overwrite模式
# model_dict=model.state_dict()
# # 1. filter out unnecessary keys
# state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
# # 2. overwrite entries in the existing state dict
# model_dict.update(state_dict)
model.load_state_dict(checkpoint)

结果这个时候反而报错,找不到相应权值的字段了。

3. 这个时候意识到,可能本来加载的时候,checkpoint中就没有权值字段, 而这情况恰恰又被overwrite模式直接忽略了,所以通过load之后并没有对初始化的模型进行权值的复写,而后来的evaluate过程使用的是初始化的模型进行测试,当然结果很差

save_checkpoint

又重新检查save模型的代码,我发现checkpoint实际上对待储存的模型进行了包装,添加了acc loss等字段,而真正的权值字典被包装进了state_dict字段,这下就可以正确改动代码了,代码如下:

assert 

最终的测试结果如下:

249ee2328144adc5349c475e552010fc.png
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值