faster rcnn 实现fine-tune

faster rcnn 实现fine-tune

源代码的网址:https://github.com/jwyang/faster-rcnn.pytorch
这个网址里面有很多训练好的模型,有voc2007 的 还有coco的,视情况选择。
接下来介绍一下背景,我下载的是voc2007的res101模型,目标是实现人 和 皮卡丘 的检测,因为原来的数据库里面没有安全帽的类别,所以需要增加这个类别,同时减去没有必要的类别,并且还要实现预训练模型的fine-tune。
1。增加类别和减少类别,直接在代码里面修改就可以了,百度有很多方法,灰常简单,这里简单说一下,train 和 test 的时候,要改的类别在lib/datasets/pascal_voc.py里面修改,demo直接在demo里面修改就可以了。
技巧:打开对应的代码,ctrl+f 查找background 或者‘person’随便一个关键词 可以快速找到
2。如何用别人训练好的模型fine-tune:
(1)我们首先要分清两个东西:
一个是自己创建的网络,就是你即将训练的网络,代码里面表现为 XXNET=XXX,为了方便下面的介绍会定义为A
一个是导入的模型,即.pth文件,这个定义为B
本例中,我的.pth文件的内容是包含自己建立的新的网络的,所谓包含就是B含有很多东西 a,b,c,d,e,f,g…,而A呢只有a,b,c,d…
如何查看上面说的东西呢,这接print(自己的网络) 查看A的内容,print(load的.pth)查看B网络。
直接上代码:
比较下面几个代码:
print(checkpoint):整个模型的所有东西,包括对应的权值,以及网络结构。
print(fasterRCNN):整个网络的详细结构,没有打印出权值
print(fasterRCNN.state_dict()):网络的结构关键词,不是很具体,但是是我需要的
print(checkpoint[‘model’].items):和上面一句类似,精简的,我们需要的
注:代码块的部分是在,train__val.py 中的 resume附近改的,那部分涉及到导入模型

save_model =checkpoint['model']   #这里返回的是一个字典,里面包含 键和对应的值,其实就是我们的网络的名字和对应的权值
model_dict = fasterRCNN.state_dict()  #这个我们只能得到对应的网络的名字

state_dict={}
state_dict3 = {}
i=1
for k, v in save_model.items():
    if k in model_dict.keys():
 #因为save_model 是包含model_dict的,所以,我们要判断的是,网络在两个变量里面同时存在,接着,把对应的网络层的值传递即可。
        if i==1:
            state_dict = {k: v}
            i=0

            continue
        if str(k)!='RCNN_cls_score.bias' and str(k)!='RCNN_cls_score.weight' and str(k)!='RCNN_bbox_pred.weight' and str(k)!='RCNN_bbox_pred.bias':
  #由于新的网络和原来的.pth有一些参数上的不同,把参数不同的网络直接丢弃即可。至于那些参数不同,可以参考代码段前的几个print结果,或者运行代码从报错里面找。
            state_dict1 = {k: v}
            state_dict=dict(state_dict,**state_dict1)
        # if str(k) == 'RCNN_cls_score.bias':
        #     b=np.zeros((3))
        #     a=torch.tensor(b)
        #     state_dict2={k:a}
        #     state_dict = dict(state_dict, **state_dict2)
        #     print('1111111111111111111111111111111111111111111111111111111111111111111')
        # if str(k) == 'RCNN_cls_score.weight':
        #     b=np.zeros((3,2048))
        #     a=torch.tensor(b)
        #     state_dict2={k:a}
        #     state_dict = dict(state_dict, **state_dict2)
        #     print('222222222222222222222222222222222222222222222222222222222222222222222')
        # if str(k) == 'RCNN_bbox_pred.weight':
        #     b=np.zeros((12,2048))
        #     a=torch.tensor(b)
        #     state_dict2={k:a}
        #     state_dict = dict(state_dict, **state_dict2)
        #     print ('33333333333333333333333333333333333333333333333333333333333333333333333')
        # if str(k) == 'RCNN_bbox_pred.bias':
        #     b = np.zeros((12))
        #     a=torch.tensor(b)
        #     state_dict2={k:a}
        #     state_dict = dict(state_dict, **state_dict2)
        #     print('4444444444444444444444444444444444444444444444444444444444444444444444444444')

model_dict.update(state_dict)
#将符合要求的网络组成一个词典,并且更新,用下面的语句送入新的网络里面。
#print('====================================================================\n',state_dict.keys())
fasterRCNN.load_state_dict(model_dict)

(2)opitimizer这个部分我没有导入,因为目前我还不知道应该修改那些参数,所以干脆丢了(没丢之前,修改网络参数之后仍然会报错,丢了之后,焕然一新)

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值