此代码位于 lib/model/faster_rcnn
这个文件夹下的resnet.py和vgg16.py是用训练好的模型提取图片特征。
重点看一下faster_rcnn.py 代码
1,初始化
1,classes ,具体什么作用还不知道
2,class_agnostic,是否专类
def init(self, classes, class_agnostic):
super(_fasterRCNN, self).init()
self.classes = classes
//类别的数量
self.n_classes = len(classes)
self.class_agnostic = class_agnostic
# 初始化分类loss和边框偏移loss
self.RCNN_loss_cls = 0
self.RCNN_loss_bbox = 0
# 定义rpn网络
self.RCNN_rpn = _RPN(self.dout_base_model)
self.RCNN_proposal_target = _ProposalTargetLayer(self.n_classes)
self.RCNN_roi_pool = _RoIPooling(cfg.POOLING_SIZE, cfg.POOLING_SIZE, 1.0/16.0)
self.RCNN_roi_align = RoIAlignAvg(cfg.POOLING_SIZE, cfg.POOLING_SIZE, 1.0/16.0)
self.grid_size = cfg.POOLING_SIZE * 2 if cfg.CROP_RESIZE_WITH_MAX_POOL else cfg.POOLING_SIZE
self.R