mmdetection训练得到的权重/checkpoints文件分析和修改

这篇文章对mmdetection(包括mmlab的其他例如mmclassification等)训练得到的模型权重,或者说checkpoints文件进行分析,一般模型保存在work-dir文件夹下,具体路径要参考训练用到的config,即配置文件。保存的模型一般是.pt的文件。

读取.pth文件具体数值

修改.pth文件具体数值(比如修改卷积核通道数)

读取.pth文件具体数值

.pt模型文件读取方法

这种模型文件可以用torch.load()函数进行解析

import torch

pth_path = 'work-dir/your_check_point.pt'
model = torch.load(pth_path)

这里我们就可以看到这个model实际上不是什么复杂的东西,就是一个很大的dict

这个model一般包括三个key、value。

meta

第一个:meta,包含一些基本信息。就是告诉你这个模型是在什么背景下被训练得到的,用的mmdet是什么版本,随机种子seed是多少,config是什么,方便你复现复刻出来这个model 

state_dict 

这个是模型关键。一般网上下载的预训练权重只有这个,其是一个大的OrderedDict里面包含了这个模型按顺序得到的各层参数,看下图就明白个大概了。

一般要利用一个checkpoint(.pt的模型权重文件) ,也就是主要读取这里面的信息,来进行refine或者infer。

optimizer 

里面存放的是优化器的状态,方便用这个.pt文件进行resume,即意外中断实验的时候进行继续实验,结合mmdet的train.py里的resume_from命令理解。

修改.pth文件具体数值(比如修改卷积核通道数)

有时候我们修改了网络部分,会导致预训练权重的shape跟网络修改后的shape不匹配。最经典的例子就是我们希望输入图片从3channels变成6channels。就会报诸如

torch: input tensor shape not match

的错误,那么这个时候为了正常用预训练模型,我们就需要手动去修改模型权重。当然如果相关module的封装特别好比如MMdetection的一些backbone和neck module,是可以手动选择in_channels来自动匹配input的shape修改的,但是这种匹配也是建立在已封装了修改具体权重的基础上,实际上我们更希望手动去控制一些修改权重的方法(比如初始化法)

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值