实战 | 源码入门之Faster RCNN

前言

学习深度学习和计算机视觉,特别是目标检测方向的学习者,一定听说过Faster Rcnn;在目标检测领域,Faster Rcnn表现出了极强的生命力,被大量的学习者学习,研究和工程应用。网上有很多版本的Faster RCNN的源码,但是很多版本代码太过于庞大,对新入门的学习者学习起来很不友好,在网上苦苦寻找了一番后终于找到了一个适合源码学习的Faster Rcnn的pytorch版本代码。

根据该版本的作者讲该代码除去注释只有两千行左右,并且经过小编的一番学习之后,发现该版本的代码真的是非常的精简干练,读起来“朗朗上口”,并且深刻的感觉到作者代码功底之深厚。在此先附上源码的地址(https://github.com/chenyuntc/simple-faster-rcnn-pytorch) ,并对源码作者(陈云)表示由衷的感谢和深深地敬意。

本文章主要的目的是对该版本代码的主要框架进行梳理,希望能够对一些想学习源码的读者有一定的帮助。

本文作者:白俊杰

代码的主要文件

-data文件中主要是文件的与dataset相关的文件

-misc中有下载caffe版本预训练模型的文件,可以不看

-model文件中主要是与构建Faster Rcnn网络模型有关的文件

-utils中主要是一些辅助可视化和验证的文件

-train.py是整个程序的运行文件,下面有一部分会做介绍

-trainer.py文件主要是用于训练,模型的损失函数的计算都在这个文件中

train

先来看一下train.py里的主要内容:

def train(train(**kwargs)):    #训练网络的主要内容(位于train.py文件中)
  opt._parse(kwargs)
  dataset = Dataset(opt)      #读取用于训练的图片及进行相关的预处理(在下文的dataset部分做详细介绍)
  dataloader = data_.DataLoader(dataset, \
                                batch_size=1, \
                                shuffle=True, \
                                # pin_memory=True,
                                num_workers=opt.num_workers)
  testset = TestDataset(opt)  #读取用于测试的图片及进行相关的预处理
  test_dataloader = data_.DataLoader(testset,
                                     batch_size=1,
                                     num_workers=opt.test_num_workers,
                                     shuffle=False, \
                                     pin_memory=True
                                     )
  faster_rcnn = FasterRCNNVGG16()    #网络结构,包含主要Extractor,RPN和RoIHead三部分结构。
  trainer = FasterRCNNTrainer(faster_rcnn).cuda()  #主要包含模型的训练过程的

  for epoch in range(opt.epoch):#开始迭代训练
        trainer.reset_meters()
        for ii, (img, bbox_, label_, scale) in tqdm(enumerate(dataloader)):
            scale = at.scalar(scale)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            trainer.train_step(img, bbox, label, scale)  #执行训练

从train.py中的主要函数可以看出,主要的步骤涉及训练数据和测试数据的预处理,网络模型的构建(Faster RCNN),然后就是迭代训练,这也是通用的神经网络搭建和训练的过程。在Faster Rcnn网络模型中主要包含Extractor、RPN和RoIhead三部分。网络中Extractor主要是利用CNN进行特征提取,网络采用的VGG16;RPN是候选区网络,为RoIHead模块提供可能存在目标的候选区域(rois);RoIHead主要负责rois的分类和微调。整体的框架图如下图所示:

图片来源于陈云的知乎

Dataset

在本版本的代码中读取的数据格式为VOC,Dataset和TestDataset类分别负责训练数据和测试数据的读取及预处理。在预处理部分主要的操作就是resize图像的大小、像素值的处理以及图像的随机翻转。主要的内容如下:


                
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值