Special Layers for Object Detection: Reading SSD's MultiBoxLoss Code

This post walks through multibox_loss_layer, one of the key pieces of SSD. The layer internally creates two sub-layers to compute the losses (location regression and confidence), and it relies on several important helper functions such as FindMatches, MineHardExamples, EncodeLocPrediction and EncodeConfPrediction (some of these live in bbox_util and will be covered in a later post).

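For reference, the overall objective this layer computes is the one from the SSD paper: a weighted sum of the confidence loss and the localization loss, averaged over the $N$ matched default boxes (the loss is defined as 0 when $N = 0$):

$$L(x, c, l, g) = \frac{1}{N}\bigl(L_{conf}(x, c) + \alpha\, L_{loc}(x, l, g)\bigr)$$

Here $\alpha$ corresponds to the layer's loc_weight_ (1.0 by default), $L_{loc}$ is a Smooth L1 loss on the encoded box offsets, and $L_{conf}$ is a softmax loss over the class confidences.
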
Code:

#include <algorithm>
#include <map>
#include <utility>
#include <vector>

#include "caffe/layers/multibox_loss_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

// LayerSetUp: this function also creates two internal layers, one for the
// loc regression and one for the conf loss.
template <typename Dtype>
void MultiBoxLossLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  LossLayer<Dtype>::LayerSetUp(bottom, top);
  if (this->layer_param_.propagate_down_size() == 0) {
    this->layer_param_.add_propagate_down(true);   // location predictions
    this->layer_param_.add_propagate_down(true);   // confidence scores
    this->layer_param_.add_propagate_down(false);  // priors
    this->layer_param_.add_propagate_down(false);  // ground truth
  }
  const MultiBoxLossParameter& multibox_loss_param =
      this->layer_param_.multibox_loss_param();
  // Member copy; it looks redundant next to the local reference above, but it
  // is used later in Forward_cpu (FindMatches, MineHardExamples, ...).
  multibox_loss_param_ = this->layer_param_.multibox_loss_param();
  num_ = bottom[0]->num();  // batch size
  num_priors_ = bottom[2]->height() / 4;  // number of priors; each prior has 4 corner coordinates
  // Get other parameters.
  CHECK(multibox_loss_param.has_num_classes()) << "Must provide num_classes.";
  num_classes_ = multibox_loss_param.num_classes();  // number of classes
  CHECK_GE(num_classes_, 1) << "num_classes should not be less than 1.";
  share_location_ = multibox_loss_param.share_location();  // share location predictions across classes, default = true
  // If shared, all classes use a single location prediction;
  // otherwise each class regresses its own boxes.
  loc_classes_ = share_location_ ? 1 : num_classes_;
  background_label_id_ = multibox_loss_param.background_label_id();  // label id of the background class
  use_difficult_gt_ = multibox_loss_param.use_difficult_gt();  // whether to use "difficult" ground truth (VOC annotation flag)
  mining_type_ = multibox_loss_param.mining_type();  // slightly different from the older SSD code
  if (multibox_loss_param.has_do_neg_mining()) {
    LOG(WARNING) << "do_neg_mining is deprecated, use mining_type instead.";
    do_neg_mining_ = multibox_loss_param.do_neg_mining();  // hard negative mining
    CHECK_EQ(do_neg_mining_,
             mining_type_ != MultiBoxLossParameter_MiningType_NONE);  // must agree with mining_type
  }
  do_neg_mining_ = mining_type_ != MultiBoxLossParameter_MiningType_NONE;
  // Loss normalization, from LossParameter; the default is VALID.
  if (!this->layer_param_.loss_param().has_normalization() &&
      this->layer_param_.loss_param().has_normalize()) {
    normalization_ = this->layer_param_.loss_param().normalize() ?
                     LossParameter_NormalizationMode_VALID :
                     LossParameter_NormalizationMode_BATCH_SIZE;
  } else {
    normalization_ = this->layer_param_.loss_param().normalization();
  }
  if (do_neg_mining_) {
    CHECK(share_location_)
        << "Currently only support negative mining if share_location is true.";
  }
  vector<int> loss_shape(1, 1);
  // Set up localization loss layer.
  loc_weight_ = multibox_loss_param.loc_weight();  // loc weight, default 1.0
  loc_loss_type_ = multibox_loss_param.loc_loss_type();  // loss type; SSD uses SMOOTH_L1
  // Fake shape; the real shape is set in Forward_cpu once the matches are known.
  vector<int> loc_shape(1, 1);
  loc_shape.push_back(4);                 // [1, 4]
  loc_pred_.Reshape(loc_shape);           // predictions, shape [1, 4]
  loc_gt_.Reshape(loc_shape);             // ground truth, shape [1, 4]
  loc_bottom_vec_.push_back(&loc_pred_);  // pointer to the prediction blob
  loc_bottom_vec_.push_back(&loc_gt_);    // pointer to the ground-truth blob
  loc_loss_.Reshape(loss_shape);          // scalar location loss, shape [1]
  loc_top_vec_.push_back(&loc_loss_);     // top blob pointer
  // Create an internal layer that computes the localization loss.
  if (loc_loss_type_ == MultiBoxLossParameter_LocLossType_L2) {
    LayerParameter layer_param;
    layer_param.set_name(this->layer_param_.name() + "_l2_loc");
    layer_param.set_type("EuclideanLoss");
    layer_param.add_loss_weight(loc_weight_);
    loc_loss_layer_ = LayerRegistry<Dtype>::CreateLayer(layer_param);
    loc_loss_layer_->SetUp(loc_bottom_vec_, loc_top_vec_);
  } else if (loc_loss_type_ == MultiBoxLossParameter_LocLossType_SMOOTH_L1) {  // SSD chooses SMOOTH_L1
    LayerParameter layer_param;
    layer_param.set_name(this->layer_param_.name() + "_smooth_L1_loc");  // mbox_loss_smooth_L1_loc
    layer_param.set_type("SmoothL1Loss");
    layer_param.add_loss_weight(loc_weight_);  // 1.0
    loc_loss_layer_ = LayerRegistry<Dtype>::CreateLayer(layer_param);
    // Takes the predictions and ground truth as bottoms and outputs loc_loss_;
    // the two blobs behind loc_bottom_vec_ are filled in Forward_cpu.
    loc_loss_layer_->SetUp(loc_bottom_vec_, loc_top_vec_);
  } else {
    LOG(FATAL) << "Unknown localization loss type.";
  }
  // Set up confidence loss layer.
  // Another internal layer, this one computing the confidence loss.
  conf_loss_type_ = multibox_loss_param.conf_loss_type();  // SSD uses SOFTMAX
  conf_bottom_vec_.push_back(&conf_pred_);  // conf_pred_ is a Blob
  conf_bottom_vec_.push_back(&conf_gt_);    // conf_gt_ is a Blob
  conf_loss_.Reshape(loss_shape);           // scalar confidence loss, shape [1]
  conf_top_vec_.push_back(&conf_loss_);
  if (conf_loss_type_ == MultiBoxLossParameter_ConfLossType_SOFTMAX) {
    CHECK_GE(background_label_id_, 0)
        << "background_label_id should be within [0, num_classes) for Softmax.";
    CHECK_LT(background_label_id_, num_classes_)
        << "background_label_id should be within [0, num_classes) for Softmax.";
    LayerParameter layer_param;
    layer_param.set_name(this->layer_param_.name() + "_softmax_conf");  // mbox_loss_softmax_conf
    layer_param.set_type("SoftmaxWithLoss");
    layer_param.add_loss_weight(Dtype(1.));
    layer_param.mutable_loss_param()->set_normalization(
        LossParameter_NormalizationMode_NONE);
    SoftmaxParameter* softmax_param = layer_param.mutable_softmax_param();
    softmax_param->set_axis(1);
    // Fake reshape; the real shapes are set in Forward_cpu.
    vector<int> conf_shape(1, 1);
    conf_gt_.Reshape(conf_shape);        // [1]
    conf_shape.push_back(num_classes_);
    conf_pred_.Reshape(conf_shape);      // [1, num_classes]
    conf_loss_layer_ = LayerRegistry<Dtype>::CreateLayer(layer_param);
    conf_loss_layer_->SetUp(conf_bottom_vec_, conf_top_vec_);
  } else if (conf_loss_type_ == MultiBoxLossParameter_ConfLossType_LOGISTIC) {
    LayerParameter layer_param;
    layer_param.set_name(this->layer_param_.name() + "_logistic_conf");
    layer_param.set_type("SigmoidCrossEntropyLoss");
    layer_param.add_loss_weight(Dtype(1.));
    // Fake reshape.
    vector<int> conf_shape(1, 1);
    conf_shape.push_back(num_classes_);
    conf_gt_.Reshape(conf_shape);
    conf_pred_.Reshape(conf_shape);
    conf_loss_layer_ = LayerRegistry<Dtype>::CreateLayer(layer_param);
    conf_loss_layer_->SetUp(conf_bottom_vec_, conf_top_vec_);
  } else {
    LOG(FATAL) << "Unknown confidence loss type.";
  }
}  // end of LayerSetUp

template <typename Dtype>
void MultiBoxLossLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  LossLayer<Dtype>::Reshape(bottom, top);
  num_ = bottom[0]->num();  // batch size
  num_priors_ = bottom[2]->height() / 4;  // the prior blob layout is [1, 2, num_priors * 4]
  num_gt_ = bottom[3]->height();  // number of ground-truth boxes
  CHECK_EQ(bottom[0]->num(), bottom[1]->num());
  // loc_classes_ is 1 when share_location_ is true, otherwise num_classes_.
  CHECK_EQ(num_priors_ * loc_classes_ * 4, bottom[0]->channels())
      << "Number of priors must match number of location predictions.";
  CHECK_EQ(num_priors_ * num_classes_, bottom[1]->channels())
      << "Number of priors must match number of confidence predictions.";
}

// Bottom blob dimensions:
//   location predictions bottom[0]: [N, num_priors * loc_classes * 4]
//   confidence scores    bottom[1]: [N, num_priors * num_classes]
//   priors               bottom[2]: [1, 2, num_priors * 4]
//   ground truth         bottom[3]: [1, 1, num_gt, 8]
template <typename Dtype>
void MultiBoxLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* loc_data = bottom[0]->cpu_data();
  const Dtype* conf_data = bottom[1]->cpu_data();
  const Dtype* prior_data = bottom[2]->cpu_data();
  const Dtype* gt_data = bottom[3]->cpu_data();
  /*
  message NormalizedBBox {
    optional float xmin = 1;
    optional float ymin = 2;
    optional float xmax = 3;
    optional float ymax = 4;
    optional int32 label = 5;
    optional bool difficult = 6;
    optional float score = 7;
    optional float size = 8;
  }
  */
  // Retrieve all ground truth: parse the boxes into all_gt_bboxes.
  map<int, vector<NormalizedBBox> > all_gt_bboxes;
  GetGroundTruth(gt_data, num_gt_, background_label_id_, use_difficult_gt_,  // background_label_id_ = 0, use_difficult_gt_ = true
                 &all_gt_bboxes);
  // Retrieve all prior bboxes. They are the same within a batch since we
  // assume all images in a batch are of the same dimension.
  // Store the prior boxes in prior_bboxes and their variances in prior_variances.
  vector<NormalizedBBox> prior_bboxes;
  vector<vector<float> > prior_variances;
  GetPriorBBoxes(prior_data, num_priors_, &prior_bboxes, &prior_variances);
  // Retrieve all predictions.
  vector<LabelBBox> all_loc_preds;  // LabelBBox = map<int, vector<NormalizedBBox> >
  // Write all predicted boxes (loc_data, i.e. bottom[0]) into all_loc_preds.
  GetLocPredictions(loc_data, num_, num_priors_, loc_classes_, share_location_,
                    &all_loc_preds);
  // Find matches between source bboxes and ground truth bboxes.
  vector<map<int, vector<float> > > all_match_overlaps;
  FindMatches(all_loc_preds, all_gt_bboxes, prior_bboxes, prior_variances,
              multibox_loss_param_, &all_match_overlaps, &all_match_indices_);
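  // (Note) FindMatches is implemented in bbox_util (covered later). As far as
  // I can tell it follows the matching strategy of the SSD paper: first a
  // bipartite step that gives every ground-truth box the prior with the
  // highest Jaccard overlap, then every remaining prior is matched to a
  // ground truth if their overlap exceeds overlap_threshold (default 0.5).
  // all_match_indices_ maps each prior index to a gt index (or -1).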
  num_matches_ = 0;
  int num_negs = 0;
  // Sample hard negative (and positive) examples based on mining type.
  MineHardExamples(*bottom[1], all_loc_preds, all_gt_bboxes, prior_bboxes,
                   prior_variances, all_match_overlaps, multibox_loss_param_,
                   &num_matches_, &num_negs, &all_match_indices_,
                   &all_neg_indices_);
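  // (Note) With the default mining_type MAX_NEGATIVE, the unmatched priors
  // are ranked by their confidence loss and only the hardest ones are kept,
  // at most neg_pos_ratio (default 3) negatives per positive, as in the SSD
  // paper; the selected negatives are recorded in all_neg_indices_.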
  if (num_matches_ >= 1) {
    // Form data to pass on to loc_loss_layer_.
    vector<int> loc_shape(2);
    loc_shape[0] = 1;
    loc_shape[1] = num_matches_ * 4;
    loc_pred_.Reshape(loc_shape);  // the pointers were already stored in loc_bottom_vec_ in LayerSetUp
    loc_gt_.Reshape(loc_shape);
    Dtype* loc_pred_data = loc_pred_.mutable_cpu_data();
    Dtype* loc_gt_data = loc_gt_.mutable_cpu_data();
    EncodeLocPrediction(all_loc_preds, all_gt_bboxes, all_match_indices_,
                        prior_bboxes, prior_variances, multibox_loss_param_,
                        loc_pred_data, loc_gt_data);
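    // (Note) With the default CENTER_SIZE code type, each ground-truth box g
    // is encoded relative to its matched prior d (width d_w, height d_h) as
    //   tx = (g_cx - d_cx) / d_w / var[0],  ty = (g_cy - d_cy) / d_h / var[1],
    //   tw = log(g_w / d_w) / var[2],       th = log(g_h / d_h) / var[3],
    // so the loss below compares encoded offsets rather than raw coordinates.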
    loc_loss_layer_->Reshape(loc_bottom_vec_, loc_top_vec_);
    loc_loss_layer_->Forward(loc_bottom_vec_, loc_top_vec_);  // forward pass
  } else {
    loc_loss_.mutable_cpu_data()[0] = 0;
  }  // end of the localization loss forward pass
  // Form data to pass on to conf_loss_layer_.
  if (do_neg_mining_) {  // only the mined positives and negatives contribute
    num_conf_ = num_matches_ + num_negs;
  } else {
    num_conf_ = num_ * num_priors_;
  }
  if (num_conf_ >= 1) {
    // Reshape the confidence data.
    vector<int> conf_shape;
    if (conf_loss_type_ == MultiBoxLossParameter_ConfLossType_SOFTMAX) {  // SSD chooses SOFTMAX
      conf_shape.push_back(num_conf_);
      conf_gt_.Reshape(conf_shape);        // [num_conf]
      conf_shape.push_back(num_classes_);
      conf_pred_.Reshape(conf_shape);      // [num_conf, num_classes]
    } else if (conf_loss_type_ == MultiBoxLossParameter_ConfLossType_LOGISTIC) {
      conf_shape.push_back(1);
      conf_shape.push_back(num_conf_);
      conf_shape.push_back(num_classes_);
      conf_gt_.Reshape(conf_shape);
      conf_pred_.Reshape(conf_shape);
    } else {
      LOG(FATAL) << "Unknown confidence loss type.";
    }
    if (!do_neg_mining_) {
      // Consider all scores.
      // Share data and diff with bottom[1].
      CHECK_EQ(conf_pred_.count(), bottom[1]->count());
      conf_pred_.ShareData(*(bottom[1]));
    }
    Dtype* conf_pred_data = conf_pred_.mutable_cpu_data();
    Dtype* conf_gt_data = conf_gt_.mutable_cpu_data();
    // Initialize every target to the background label; EncodeConfPrediction
    // then overwrites the entries of matched priors with their gt labels.
    caffe_set(conf_gt_.count(), Dtype(background_label_id_), conf_gt_data);
    EncodeConfPrediction(conf_data, num_, num_priors_, multibox_loss_param_,
                         all_match_indices_, all_neg_indices_, all_gt_bboxes,
                         conf_pred_data, conf_gt_data);
    conf_loss_layer_->Reshape(conf_bottom_vec_, conf_top_vec_);
    conf_loss_layer_->Forward(conf_bottom_vec_, conf_top_vec_);
  } else {
    conf_loss_.mutable_cpu_data()[0] = 0;
  }  // end of the confidence loss forward pass
  top[0]->mutable_cpu_data()[0] = 0;
  if (this->layer_param_.propagate_down(0)) {  // add the normalized localization loss
    Dtype normalizer = LossLayer<Dtype>::GetNormalizer(
        normalization_, num_, num_priors_, num_matches_);
    top[0]->mutable_cpu_data()[0] +=
        loc_weight_ * loc_loss_.cpu_data()[0] / normalizer;
  }
  if (this->layer_param_.propagate_down(1)) {  // add the normalized confidence loss
    Dtype normalizer = LossLayer<Dtype>::GetNormalizer(
        normalization_, num_, num_priors_, num_matches_);
    top[0]->mutable_cpu_data()[0] += conf_loss_.cpu_data()[0] / normalizer;
  }
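  // (Note) With the default VALID normalization, GetNormalizer returns
  // num_matches_, i.e. the number N of matched priors, so the top blob holds
  // (loc_weight_ * loc_loss + conf_loss) / N -- the 1/N weighting used in
  // the SSD paper.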
}  // end of Forward_cpu

template <typename Dtype>
void MultiBoxLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {
  if (propagate_down[2]) {
    LOG(FATAL) << this->type()
               << " Layer cannot backpropagate to prior inputs.";
  }
  if (propagate_down[3]) {
    LOG(FATAL) << this->type()
               << " Layer cannot backpropagate to label inputs.";
  }
  // Back propagate on location prediction first.
  if (propagate_down[0]) {
    Dtype* loc_bottom_diff = bottom[0]->mutable_cpu_diff();
    caffe_set(bottom[0]->count(), Dtype(0), loc_bottom_diff);
    if (num_matches_ >= 1) {
      vector<bool> loc_propagate_down;
      // Only back propagate on prediction, not ground truth.
      loc_propagate_down.push_back(true);
      loc_propagate_down.push_back(false);
      loc_loss_layer_->Backward(loc_top_vec_, loc_propagate_down,
                                loc_bottom_vec_);
      // Scale gradient.
      Dtype normalizer = LossLayer<Dtype>::GetNormalizer(
          normalization_, num_, num_priors_, num_matches_);
      Dtype loss_weight = top[0]->cpu_diff()[0] / normalizer;
      caffe_scal(loc_pred_.count(), loss_weight, loc_pred_.mutable_cpu_diff());
      // Copy gradient back to bottom[0]: loc_pred_'s diff is densely packed,
      // so scatter each matched prior's 4 values back to its original slot.
      const Dtype* loc_pred_diff = loc_pred_.cpu_diff();
      int count = 0;
      for (int i = 0; i < num_; ++i) {
        for (map<int, vector<int> >::iterator it =
                 all_match_indices_[i].begin();
             it != all_match_indices_[i].end(); ++it) {
          const int label = share_location_ ? 0 : it->first;
          const vector<int>& match_index = it->second;
          for (int j = 0; j < match_index.size(); ++j) {
            if (match_index[j] <= -1) {
              continue;  // skip unmatched priors
            }
            // Copy the diff to the right place.
            int start_idx = loc_classes_ * 4 * j + label * 4;
            caffe_copy<Dtype>(4, loc_pred_diff + count * 4,
                              loc_bottom_diff + start_idx);
            ++count;
          }
        }
        loc_bottom_diff += bottom[0]->offset(1);  // advance to the next image
      }
    }
  }
  // Back propagate on confidence prediction.
  if (propagate_down[1]) {
    Dtype* conf_bottom_diff = bottom[1]->mutable_cpu_diff();
    caffe_set(bottom[1]->count(), Dtype(0), conf_bottom_diff);
    if (num_conf_ >= 1) {
      vector<bool> conf_propagate_down;
      // Only back propagate on prediction, not ground truth.
      conf_propagate_down.push_back(true);
      conf_propagate_down.push_back(false);
      conf_loss_layer_->Backward(conf_top_vec_, conf_propagate_down,
                                 conf_bottom_vec_);
      // Scale gradient.
      Dtype normalizer = LossLayer<Dtype>::GetNormalizer(
          normalization_, num_, num_priors_, num_matches_);
      Dtype loss_weight = top[0]->cpu_diff()[0] / normalizer;
      caffe_scal(conf_pred_.count(), loss_weight,
                 conf_pred_.mutable_cpu_diff());
      // Copy gradient back to bottom[1].
      const Dtype* conf_pred_diff = conf_pred_.cpu_diff();
      if (do_neg_mining_) {
        int count = 0;
        for (int i = 0; i < num_; ++i) {
          // Copy matched (positive) bboxes' scores diff.
          const map<int, vector<int> >& match_indices = all_match_indices_[i];
          for (map<int, vector<int> >::const_iterator it =
                   match_indices.begin(); it != match_indices.end(); ++it) {
            const vector<int>& match_index = it->second;
            CHECK_EQ(match_index.size(), num_priors_);
            for (int j = 0; j < num_priors_; ++j) {
              if (match_index[j] <= -1) {
                continue;
              }
              // Copy the diff to the right place.
              caffe_copy<Dtype>(num_classes_,
                                conf_pred_diff + count * num_classes_,
                                conf_bottom_diff + j * num_classes_);
              ++count;
            }
          }
          // Copy negative bboxes' scores diff.
          for (int n = 0; n < all_neg_indices_[i].size(); ++n) {
            int j = all_neg_indices_[i][n];
            CHECK_LT(j, num_priors_);
            caffe_copy<Dtype>(num_classes_,
                              conf_pred_diff + count * num_classes_,
                              conf_bottom_diff + j * num_classes_);
            ++count;
          }
          conf_bottom_diff += bottom[1]->offset(1);
        }
      } else {
        // The diff is already computed and stored.
        bottom[1]->ShareDiff(conf_pred_);
      }
    }
  }
  // After backward, remove match statistics.
  all_match_indices_.clear();
  all_neg_indices_.clear();
}

INSTANTIATE_CLASS(MultiBoxLossLayer);
REGISTER_LAYER_CLASS(MultiBoxLoss);

}  // namespace caffe
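
As a quick reference for the SmoothL1Loss layer created above: it applies the smooth L1 function from Fast R-CNN elementwise to the difference between the encoded predictions and targets. A minimal sketch of that function (not the actual Caffe kernel):

#include <cmath>

// Smooth L1 from Fast R-CNN: quadratic near zero, linear tail so the
// gradient is bounded by +/-1 for large errors.
template <typename Dtype>
Dtype SmoothL1(Dtype x) {
  const Dtype abs_x = std::abs(x);
  if (abs_x < Dtype(1)) {
    return Dtype(0.5) * x * x;
  }
  return abs_x - Dtype(0.5);
}

This is why the localization loss is robust to badly mismatched boxes: an outlier contributes a bounded gradient instead of an L2-style gradient that grows with the error.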
