通过上一篇文章,我们基本把faster rcnn中用到的知识点过了一编,其实也没多少东西,从这一篇往后,我们开始做code reading.
代码来源simple-faster-rcnn-pytorch
这部分我们主要阅读代码train.py和trainer.py
训练器
首先找到train.py文件
def train(**kwargs):#rows 50
faster_rcnn = FasterRCNNVGG16()
trainer = FasterRCNNTrainer(faster_rcnn).cuda()
....
trainer.train_step(img, bbox, label, scale)#rows 81
通过这几行代码,我们可以知道,训练器tranier是我们主要查看对象,而对img和后面跟着的bbox和label应该是对应的ground truth,而FasterRCNNVGG16里面应该包含对应主体网络,而train_step里面应该包含前向和反向传播过程。
所以我们重点看trainer
在trainer.py中首先查看train_step的实现
def train_step(self, imgs, bboxes, labels, scale):#rows 167
self.optimizer.zero_grad()
losses = self.forward(imgs, bboxes, labels, scale)
losses.total_loss.backward()
self.optimizer.step()
self.update_meters(losses)
return losses
果然包含其前向和反向传播的过程,我们这里重点看forward过程
class FasterRCNNTrainer(nn.Module):#rows 25
def __init__(self, faster_rcnn):#rows 42
self.faster_rcnn = faster_rcnn
...
def forward(self, imgs, bboxes, labels, scale):#rows 65
...
features = self.faster_rcnn.extractor(imgs)#rows 97
rpn_locs, rpn_scores, rois, roi_indices, anchor = \
self.faster_rcnn.rpn(features, img_size, scale)
...
sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator(...)#rows 112
...
roi_cls_loc, roi_score = self.faster_rcnn.head(...)#rows 120
gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(...)#rows 126
####
后面是损失函数
####
....
在forward函数中,我们就可以看到大致的网络传播结构,首先是feature map通过extractor来进行提取,然后生成rpn,然后对生成的bbox找到ground truth,然后对每个bbox再进行类型的预测,然后再计算损失,这就是简单的网络传播结构。
训练器中涉及到的主体网络和相关接口
通过上面的解析,我们有了整体的了解,然后我们从接口方面再进行剖析,首先上面的forward中涉及到了几个很重要的接口,首先是和网络相关的接口,
self.faster_rcnn.extractor()#rows 97
self.faster_rcnn.rpn()#rows 99
self.faster_rcnn.head()#rows 120
然后是和ground truth有关的接口,
self.proposal_target_creator()#rows 112
self.anchor_target_creator()#rows 126
因此我们就从这几个接口作为解析网络的入口。
主体网络
主体网络在model/faster_rcnn_vgg16.py中,我们从上面的接口开始分析网络,
extractor()
通过代码搜索,我们可以定位到extractor的位置
extractor, classifier = decom_vgg16()#rows 63
找到decom_vgg16函数
def decom_vgg16():#rows 12
...
model = vgg16(pretrained=False)#rows 15
...
features = list(model.features)[:30] #rows 21
classifier = model.classifier
....
return nn.Sequential(*features), classifier
通过上面这里的特征提取网络从预训练好的网络VGG16中提取出来的,通过代码我们也可以知道对应的层数。
rpn()
rpn()的位置:
rpn = RegionProposalNetwork()#rows 65
可以知道这就是我们上一篇博客中提到的RPN网络,在model/region_proposal_network.py文件中,这个留作后续分析。
head()
head()的位置
head = VGG16RoIHead()#rows 72
找到VGG16RoIHead类
class VGG16RoIHead(nn.Module):#rows 86
...
def forward(self, x, rois, roi_indices):#rows 117
...
fc7 = self.classifier(pool) #rows 135
roi_cls_locs = self.cls_loc(fc7)
roi_scores = self.score(fc7)
return roi_cls_locs, roi_scores
从上面代码我们可以知道这个就是对给定的ROI预测其位置和分类的网络,也就是主体网络最后用到的那个网络。
Ground Truth
知道了相关的网络接口,那么如果求损失,我们需要知道grond truth,怎么获取呢?我们需要对上面的grond truth的接口进行分析:
在trainer.py中进行定位
self.anchor_target_creator = AnchorTargetCreator()#rows 50
self.proposal_target_creator = ProposalTargetCreator()
我们很快找到其都包含在model/utils/creator_tool.py文件中
anchor_target_creator
class AnchorTargetCreator(object):#rows 136
def __call__(self, bbox, anchor, img_size):#rows 170
...
return loc, label
通过对代码分析,我们可以了解到其可以对输入的anchor产生器对应真实的loc和label
proposal_target_creator
class ProposalTargetCreator(object):#rows 8
def __call__(self, roi, bbox, label,...)#rows 43
...
return sample_roi, gt_roi_loc, gt_roi_label
通过观测该函数我们可以知道其是对给定的roi产生真实的bbox位置和标签。
损失函数
通过对上面分析,特别是对trainer.py中的forward函数的观察,我们可以知道,后面这两个ground truth生成器都是对self.faster_rcnn.rpn网络中的求得的结果找到其对应的ground truth,进而求损失:
def forward(self, imgs, bboxes, labels, scale):#rows 65
...
rpn_locs, rpn_scores, rois, roi_indices, anchor = \
self.faster_rcnn.rpn(features, img_size, scale)#rows 99
...
roi = rois #rows 107
...
sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator(
roi,
at.tonumpy(bbox),
at.tonumpy(label),...)#rows 112
...
gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(
at.tonumpy(bbox),
anchor,
img_size)
...
roi_cls_loc, roi_score = self.faster_rcnn.head(
features,
sample_roi,
sample_roi_index)#rows 120
...
roi_cls_loss = nn.CrossEntropyLoss()(roi_score, gt_roi_label.cuda())#158
那么对于最后的分类网络的损失呢,就是head网络输出,然后求损失了,最后简易总结如下:
文章中有错误,还请您指出,一起进步。