从源码解读Faster-RCNN--(2)整体结构

通过上一篇文章,我们基本把faster rcnn中用到的知识点过了一编,其实也没多少东西,从这一篇往后,我们开始做code reading.
代码来源simple-faster-rcnn-pytorch
这部分我们主要阅读代码train.py和trainer.py

训练器

首先找到train.py文件

def train(**kwargs):#rows 50
	faster_rcnn = FasterRCNNVGG16()
	trainer = FasterRCNNTrainer(faster_rcnn).cuda()
	....
	trainer.train_step(img, bbox, label, scale)#rows 81

通过这几行代码,我们可以知道,训练器tranier是我们主要查看对象,而对img和后面跟着的bbox和label应该是对应的ground truth,而FasterRCNNVGG16里面应该包含对应主体网络,而train_step里面应该包含前向和反向传播过程。
所以我们重点看trainer
在trainer.py中首先查看train_step的实现

def train_step(self, imgs, bboxes, labels, scale):#rows 167
    self.optimizer.zero_grad()
    losses = self.forward(imgs, bboxes, labels, scale)
    losses.total_loss.backward()
    self.optimizer.step()
    self.update_meters(losses)
    return losses

果然包含其前向和反向传播的过程,我们这里重点看forward过程

class FasterRCNNTrainer(nn.Module):#rows 25
	def __init__(self, faster_rcnn):#rows 42
		self.faster_rcnn = faster_rcnn
		...
	def forward(self, imgs, bboxes, labels, scale):#rows 65
		...
		features = self.faster_rcnn.extractor(imgs)#rows 97
        rpn_locs, rpn_scores, rois, roi_indices, anchor = \
            self.faster_rcnn.rpn(features, img_size, scale)
        ...
      	sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator(...)#rows 112
      	...
      	roi_cls_loc, roi_score = self.faster_rcnn.head(...)#rows 120
      	gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(...)#rows 126  		
      	####
      	后面是损失函数
      	####
      	....

在forward函数中,我们就可以看到大致的网络传播结构,首先是feature map通过extractor来进行提取,然后生成rpn,然后对生成的bbox找到ground truth,然后对每个bbox再进行类型的预测,然后再计算损失,这就是简单的网络传播结构。

训练器中涉及到的主体网络和相关接口

通过上面的解析,我们有了整体的了解,然后我们从接口方面再进行剖析,首先上面的forward中涉及到了几个很重要的接口,首先是和网络相关的接口,

self.faster_rcnn.extractor()#rows 97
self.faster_rcnn.rpn()#rows 99
self.faster_rcnn.head()#rows 120

然后是和ground truth有关的接口,

self.proposal_target_creator()#rows 112
self.anchor_target_creator()#rows 126

因此我们就从这几个接口作为解析网络的入口。

主体网络

主体网络在model/faster_rcnn_vgg16.py中,我们从上面的接口开始分析网络,
extractor()
通过代码搜索,我们可以定位到extractor的位置

extractor, classifier = decom_vgg16()#rows 63

找到decom_vgg16函数

def decom_vgg16():#rows 12
	...
	 model = vgg16(pretrained=False)#rows 15
	...
	features = list(model.features)[:30] #rows 21
	classifier = model.classifier
	....
	return nn.Sequential(*features), classifier

通过上面这里的特征提取网络从预训练好的网络VGG16中提取出来的,通过代码我们也可以知道对应的层数。
rpn()
rpn()的位置:

rpn = RegionProposalNetwork()#rows 65

可以知道这就是我们上一篇博客中提到的RPN网络,在model/region_proposal_network.py文件中,这个留作后续分析。
head()
head()的位置

head = VGG16RoIHead()#rows 72

找到VGG16RoIHead类

class VGG16RoIHead(nn.Module):#rows 86
	...
	def forward(self, x, rois, roi_indices):#rows 117
		...
		fc7 = self.classifier(pool) #rows 135
		roi_cls_locs = self.cls_loc(fc7)
		roi_scores = self.score(fc7)
	    return roi_cls_locs, roi_scores

从上面代码我们可以知道这个就是对给定的ROI预测其位置和分类的网络,也就是主体网络最后用到的那个网络。

Ground Truth

知道了相关的网络接口,那么如果求损失,我们需要知道grond truth,怎么获取呢?我们需要对上面的grond truth的接口进行分析:
在trainer.py中进行定位

self.anchor_target_creator = AnchorTargetCreator()#rows 50
self.proposal_target_creator = ProposalTargetCreator()

我们很快找到其都包含在model/utils/creator_tool.py文件中
anchor_target_creator

class AnchorTargetCreator(object):#rows 136
	def __call__(self, bbox, anchor, img_size):#rows 170
		...
		return loc, label

通过对代码分析,我们可以了解到其可以对输入的anchor产生器对应真实的loc和label
proposal_target_creator

class ProposalTargetCreator(object):#rows 8
	def __call__(self, roi, bbox, label,...)#rows 43
		...
		return sample_roi, gt_roi_loc, gt_roi_label

通过观测该函数我们可以知道其是对给定的roi产生真实的bbox位置和标签。

损失函数

通过对上面分析,特别是对trainer.py中的forward函数的观察,我们可以知道,后面这两个ground truth生成器都是对self.faster_rcnn.rpn网络中的求得的结果找到其对应的ground truth,进而求损失:

def forward(self, imgs, bboxes, labels, scale):#rows 65
	...
    rpn_locs, rpn_scores, rois, roi_indices, anchor = \
       self.faster_rcnn.rpn(features, img_size, scale)#rows 99
    ...
    roi = rois #rows 107
    ...
    sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator(
        roi,
        at.tonumpy(bbox),
        at.tonumpy(label),...)#rows 112
    ...
    gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(
        at.tonumpy(bbox),
        anchor,
        img_size)
	...
    roi_cls_loc, roi_score = self.faster_rcnn.head(
        features,
        sample_roi,
        sample_roi_index)#rows 120
     ...
     roi_cls_loss = nn.CrossEntropyLoss()(roi_score, gt_roi_label.cuda())#158

那么对于最后的分类网络的损失呢,就是head网络输出,然后求损失了,最后简易总结如下:
faster rcnn loss
文章中有错误,还请您指出,一起进步。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值