SSD Source Code Walkthrough, Part 1: The Data Layer (AnnotatedDataLayer)

Copyright notice: this is an original article by the author; please credit the source when reposting. https://blog.csdn.net/qianqing13579/article/details/80146281

Preface

From after Chinese New Year until now, I have used my spare time to work through the Caffe SSD source code bit by bit. Although I paused for a while because of work, I eventually finished, and reading the SSD source was one of the more important items on my plan for this year, so completing it feels like a real accomplishment. My biggest takeaway from reading the code is that many details that used to puzzle me in the paper suddenly make sense now.

While reading the code, whenever something puzzled me I would turn it over in my mind again and again, thinking about it while walking or eating, and often pacing back and forth in my dorm room. This gave me a new perspective on reading code. When you face a piece of code that is hard for you, don't be afraid. Calm down, split the hard code into N sub-blocks, and conquer each sub-block one by one. Once you have conquered them all, connect them back into a whole and think about the code again at that higher level; you will understand it far more deeply. Of course this is hard at first, because in the beginning there is so much you don't understand. It was like that for me with the SSD code: it was difficult at the start and there were many details to learn. But with patience and persistence you will find yourself growing more and more familiar with the material and more and more at ease, until everything suddenly becomes clear and you realize this hard code is not so hard after all. That feeling is wonderful.

On the first day of the May Day holiday I organized my SSD reading notes into this blog post to share and discuss with everyone. Since the SSD source is fairly complex and my time and energy are limited, I cannot claim a deep understanding of every detail; if you find shortcomings in this post, I welcome your feedback.

This post is the first in the SSD source-code walkthrough series and covers the data layer.

While reading the SSD source I created a QT project for it, which makes navigation much easier. I have uploaded this QT project to CSDN; it can be opened directly in QT, so feel free to download it to read the code more efficiently.
Download link


Source code walkthrough of the data layer AnnotatedDataLayer

#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>
#endif  // USE_OPENCV
#include <stdint.h>

#include <algorithm>
#include <map>
#include <vector>

#include "caffe/data_transformer.hpp"
#include "caffe/layers/annotated_data_layer.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/sampler.hpp"

namespace caffe {

template <typename Dtype>
AnnotatedDataLayer<Dtype>::AnnotatedDataLayer(const LayerParameter& param)
  : BasePrefetchingDataLayer<Dtype>(param),
    reader_(param) {
}

template <typename Dtype>
AnnotatedDataLayer<Dtype>::~AnnotatedDataLayer() {
  this->StopInternalThread();
}

template <typename Dtype>
void AnnotatedDataLayer<Dtype>::DataLayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  const int batch_size = this->layer_param_.data_param().batch_size();
  const AnnotatedDataParameter& anno_data_param = this->layer_param_.annotated_data_param();

  // Read all of the batch sampler parameters used for data augmentation
  for (int i = 0; i < anno_data_param.batch_sampler_size(); ++i) 
  {
    batch_samplers_.push_back(anno_data_param.batch_sampler(i));
  }
  label_map_file_ = anno_data_param.label_map_file();

  // Make sure dimension is consistent within batch.
  const TransformationParameter& transform_param = this->layer_param_.transform_param();
  if (transform_param.has_resize_param()) 
  {
    if (transform_param.resize_param().resize_mode() == ResizeParameter_Resize_mode_FIT_SMALL_SIZE) 
    {
      CHECK_EQ(batch_size, 1) << "Only support batch size of 1 for FIT_SMALL_SIZE.";
    }
  }

  // Read one datum and use its shape to initialize the shapes of top and of the
  // prefetch buffers (e.g. 300x300 data).
  // An AnnotatedDatum contains both the data and the annotation (the annotation
  // holds the label and the bounding boxes).
  // Read a data point, and use it to initialize the top blob.
  AnnotatedDatum& anno_datum = *(reader_.full().peek()); // the datum from reader_ is the raw input (image data plus bounding box coordinates)
  // Use data_transformer to infer the expected blob shape from anno_datum.
  vector<int> top_shape = this->data_transformer_->InferBlobShape(anno_datum.datum());
  this->transformed_data_.Reshape(top_shape);
  // Reshape top[0] and prefetch_data according to the batch_size.
  top_shape[0] = batch_size;
  top[0]->Reshape(top_shape);

  // Image data blobs used by the prefetch threads
  for (int i = 0; i < this->PREFETCH_COUNT; ++i) 
  {
    this->prefetch_[i].data_.Reshape(top_shape);
  }
  LOG(INFO) << "output data size: " << top[0]->num() << ","<< top[0]->channels() << "," << top[0]->height() << ","<< top[0]->width();

  // label
  if (this->output_labels_) 
  {
    // When the data was generated it was given a type: anno_datum.set_type(AnnotatedDatum_AnnotationType_BBOX);
    has_anno_type_ = anno_datum.has_type() || anno_data_param.has_anno_type();
    vector<int> label_shape(4, 1);
    if (has_anno_type_) 
    {
      anno_type_ = anno_datum.type();
      if (anno_data_param.has_anno_type()) 
      {
        // If anno_type is provided in AnnotatedDataParameter, replace
        // the type stored in each individual AnnotatedDatum.
        LOG(WARNING) << "type stored in AnnotatedDatum is shadowed.";
        anno_type_ = anno_data_param.anno_type();
      }
      // Infer the label shape from anno_datum.AnnotationGroup().
      int num_bboxes = 0;

      // Count all the bounding boxes in this image
      if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) 
      {
        // Since the number of bboxes can be different for each image,
        // we store the bbox information in a specific format. In specific:
        // All bboxes are stored in one spatial plane (num and channels are 1)
        // And each row contains one and only one box in the following format:
        // [item_id, group_label, instance_id, xmin, ymin, xmax, ymax, diff]
        // Note: Refer to caffe.proto for details about group_label and
        // instance_id.
        for (int g = 0; g < anno_datum.annotation_group_size(); ++g) {
          num_bboxes += anno_datum.annotation_group(g).annotation_size();
        }
        label_shape[0] = 1;
        label_shape[1] = 1;
        // BasePrefetchingDataLayer<Dtype>::LayerSetUp() requires to call
        // cpu_data and gpu_data for consistent prefetch thread. Thus we make
        // sure there is at least one bbox.
        label_shape[2] = std::max(num_bboxes, 1);
        label_shape[3] = 8;
      } 
      else 
      {
        LOG(FATAL) << "Unknown annotation type.";
      }
    } 
    else 
    {
      label_shape[0] = batch_size;
    }
    top[1]->Reshape(label_shape);

    // Label blobs used by the prefetch threads
    for (int i = 0; i < this->PREFETCH_COUNT; ++i) 
    {
      this->prefetch_[i].label_.Reshape(label_shape);
    }
  }
}

// This function is called on prefetch thread
template<typename Dtype>
void AnnotatedDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) 
{
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());

  // Reshape according to the first anno_datum of each batch
  // on single input batches allows for inputs of varying dimension.
  const int batch_size = this->layer_param_.data_param().batch_size();
  const AnnotatedDataParameter& anno_data_param = this->layer_param_.annotated_data_param();
  const TransformationParameter& transform_param = this->layer_param_.transform_param();

  // Initialize the shapes of transformed_data_ and batch->data_
  AnnotatedDatum& anno_datum = *(reader_.full().peek());
  vector<int> top_shape = this->data_transformer_->InferBlobShape(anno_datum.datum()); // e.g. 3x300x300
  this->transformed_data_.Reshape(top_shape); // transformed_data_ holds a single image; for SSD300 its shape is [1, 3, 300, 300]
  top_shape[0] = batch_size;
  batch->data_.Reshape(top_shape); // batch->data_ holds batch_size images; for SSD300 its shape is [batch_size, 3, 300, 300]

  Dtype* top_data = batch->data_.mutable_cpu_data();
  Dtype* top_label = NULL;  // suppress warnings about uninitialized variables
  if (this->output_labels_ && !has_anno_type_) 
  {
    top_label = batch->label_.mutable_cpu_data();
  }

  // Store transformed annotation.
  map<int, vector<AnnotationGroup> > all_anno; // every image in the batch together with its annotations
  int num_bboxes = 0;

  for (int item_id = 0; item_id < batch_size; ++item_id) 
  {
    timer.Start();

    // Fetch one image and preprocess it (e.g. apply photometric distortions)
    AnnotatedDatum& anno_datum = *(reader_.full().pop("Waiting for data"));
    read_time += timer.MicroSeconds();
    timer.Start();
    AnnotatedDatum distort_datum;
    AnnotatedDatum* expand_datum = NULL;
    if (transform_param.has_distort_param()) 
    {
      distort_datum.CopyFrom(anno_datum);
      this->data_transformer_->DistortImage(anno_datum.datum(),
                                            distort_datum.mutable_datum());
      if (transform_param.has_expand_param()) 
      {
        expand_datum = new AnnotatedDatum();
        this->data_transformer_->ExpandImage(distort_datum, expand_datum);
      } 
      else 
      {
        expand_datum = &distort_datum;
      }
    } 
    else 
    {
      if (transform_param.has_expand_param()) 
      {
        expand_datum = new AnnotatedDatum();
        this->data_transformer_->ExpandImage(anno_datum, expand_datum);
      } 
      else 
      {
        expand_datum = &anno_datum;
      }
    }

    AnnotatedDatum* sampled_datum = NULL;
    bool has_sampled = false;


    if (batch_samplers_.size() > 0)
    {
      /* 1. Data augmentation first (Section 2.2 "Training", data augmentation, in the paper).
       * For each image in the batch, every sampler (batch_sampler) generates up to
       * max_sample candidate bounding boxes.
       * The samplers generate boxes whose IoU with the objects is 0.1, 0.3, 0.5, 0.7
       * or 0.9, consistent with the description in the paper.
       * Example:
          batch_sampler
          {
            sampler
            {
              min_scale: 0.3
              max_scale: 1.0
              min_aspect_ratio: 0.5
              max_aspect_ratio: 2.0
            }
            sample_constraint
            {
              min_jaccard_overlap: 0.7
            }
            max_sample: 1
            max_trials: 50
          }
       *  With this sampler, a randomly generated bounding box is kept only if its IoU
       *  with some object in the image is > 0.7.
       *  Notes:
       *    1. The generated box coordinates are normalized, so they are unaffected by
       *       resizing; bounding-box regression in detectors generally uses this form
       *       (e.g. MTCNN).
       *    2. Boxes are generated randomly according to each batch_sampler's scale and
       *       aspect-ratio parameters, with at most max_trials attempts per sampler.
       */
      vector<NormalizedBBox> sampled_bboxes; // the generated coordinates are normalized
      GenerateBatchSamples(*expand_datum, batch_samplers_, &sampled_bboxes);


      /* 2. Randomly pick one bounding box from all the generated ones,
       * crop out the image region it covers (the size of sampled_bboxes[rand_idx] in
       * the original image) and compute the coordinates and class of every object
       * inside that bounding box.
       * Note:
       *    1. object coordinates inside the box = (ground-truth coordinates in the
       *       original image - box coordinates) / (box side length);
       *       here both the ground truth and the box are in original-image
       *       coordinates. MTCNN uses the same computation.
       */
      if (sampled_bboxes.size() > 0)
      {
        int rand_idx = caffe_rng_rand() % sampled_bboxes.size();
        sampled_datum = new AnnotatedDatum();
        this->data_transformer_->CropImage(*expand_datum,sampled_bboxes[rand_idx],sampled_datum);

        has_sampled = true;
      } 
      else 
      {
        sampled_datum = expand_datum;
      }
    }
    else 
    {
      sampled_datum = expand_datum;
    }
    CHECK(sampled_datum != NULL);
    timer.Start();
    vector<int> shape = this->data_transformer_->InferBlobShape(sampled_datum->datum());
    if (transform_param.has_resize_param()) 
    {
      // This branch is normally not executed
      if (transform_param.resize_param().resize_mode() == ResizeParameter_Resize_mode_FIT_SMALL_SIZE) 
      {
        this->transformed_data_.Reshape(shape);
        batch->data_.Reshape(shape);
        top_data = batch->data_.mutable_cpu_data();
      } 
      else 
      {
        CHECK(std::equal(top_shape.begin() + 1, top_shape.begin() + 4,shape.begin() + 1));
      }
    } 
    else 
    {
      CHECK(std::equal(top_shape.begin() + 1, top_shape.begin() + 4,
            shape.begin() + 1));
    }
    // Apply data transformations (mirror, scale, crop...)
    int offset = batch->data_.offset(item_id);
    this->transformed_data_.set_cpu_data(top_data + offset);
    vector<AnnotationGroup> transformed_anno_vec;
    if (this->output_labels_) 
    {
      if (has_anno_type_) 
      {
        // Make sure all data have same annotation type.
        CHECK(sampled_datum->has_type()) << "Some datum misses AnnotationType.";
        if (anno_data_param.has_anno_type()) 
        {
          sampled_datum->set_type(anno_type_);
        } 
        else 
        {
          CHECK_EQ(anno_type_, sampled_datum->type()) <<
              "Different AnnotationType.";
        }

        // Transform datum and annotation_group at the same time
        transformed_anno_vec.clear();

        // AnnotatedDatum, Blob<float>, vector<AnnotationGroup>

        /* 3. Split the cropped AnnotatedDatum into a data part and an annotation part.
         *  The data part is resized to the size configured in the data layer
         *  (e.g. 300x300) and stored in top[0].
         *  The annotations are the coordinates of all objects in the image.
         *
         * Notes:
         *  1. The image here is not necessarily the original crop: if transform_param
         *     has a crop_size parameter, the cropped image is cropped once more.
         *  2. Since the cropped image is resized here, also resizing when generating
         *     the lmdb makes the data layer effectively resize the original image
         *     twice, which can distort object aspect ratios. SFD (Single Shot
         *     Scale-invariant Face Detector) improves on this: in step 1 it forces
         *     every generated bounding box to be square, so resizing to 300x300 does
         *     not change the objects' aspect ratios.
         */
        this->data_transformer_->Transform(*sampled_datum,&(this->transformed_data_),&transformed_anno_vec);
        if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) 
        {
          // Count the number of bboxes.
          // Count how many objects fall inside the randomly generated bounding box
          for (int g = 0; g < transformed_anno_vec.size(); ++g) 
          {
            num_bboxes += transformed_anno_vec[g].annotation_size();
          }
        } 
        else 
        {
          LOG(FATAL) << "Unknown annotation type.";
        }

        // Annotations of the item_id-th image in the batch
        all_anno[item_id] = transformed_anno_vec;
      } 
      else 
      {
        this->data_transformer_->Transform(sampled_datum->datum(),&(this->transformed_data_));
        // Otherwise, store the label from datum.
        CHECK(sampled_datum->datum().has_label()) << "Cannot find any label.";
        top_label[item_id] = sampled_datum->datum().label();
      }
    } 
    else 
    {
      this->data_transformer_->Transform(sampled_datum->datum(),&(this->transformed_data_));
    }
    // clear memory
    if (has_sampled) {
      delete sampled_datum;
    }
    if (transform_param.has_expand_param()) {
      delete expand_datum;
    }
    trans_time += timer.MicroSeconds();

    // Return the consumed datum to the free queue
    reader_.free().push(const_cast<AnnotatedDatum*>(&anno_datum));
  }

  // Store "rich" annotation if needed.
  /* 4. Finally store the annotations in top[1], whose shape is [1, 1, numberOfBoxes, 8].
   * Each row has the format [item_id, group_label, instance_id, xmin, ymin, xmax, ymax, diff].
   * This 8-dimensional vector means: in the item_id-th image of the batch, the
   * instance_id-th box of class group_label has coordinates [xmin, ymin, xmax, ymax].
   */
  if (this->output_labels_ && has_anno_type_) 
  {
    vector<int> label_shape(4);
    if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) 
    {
      label_shape[0] = 1;
      label_shape[1] = 1;
      label_shape[3] = 8;
      if (num_bboxes == 0) 
      {
        // Store all -1 in the label.
        label_shape[2] = 1;
        batch->label_.Reshape(label_shape);
        caffe_set<Dtype>(8, -1, batch->label_.mutable_cpu_data());
      } 
      else 
      {

        // num_bboxes is the total number of objects across all of the cropped images above
        label_shape[2] = num_bboxes;
        batch->label_.Reshape(label_shape);
        top_label = batch->label_.mutable_cpu_data();
        int idx = 0;

        // Iterate over the label information of every image in the batch
        for (int item_id = 0; item_id < batch_size; ++item_id) 
        {
            // Label information of the item_id-th image
          const vector<AnnotationGroup>& anno_vec = all_anno[item_id];
          for (int g = 0; g < anno_vec.size(); ++g) 
          {
            const AnnotationGroup& anno_group = anno_vec[g];

            for (int a = 0; a < anno_group.annotation_size(); ++a) 
            {
              const Annotation& anno = anno_group.annotation(a);
              const NormalizedBBox& bbox = anno.bbox();

              top_label[idx++] = item_id;
              top_label[idx++] = anno_group.group_label();
              top_label[idx++] = anno.instance_id();
              top_label[idx++] = bbox.xmin();
              top_label[idx++] = bbox.ymin();
              top_label[idx++] = bbox.xmax();
              top_label[idx++] = bbox.ymax();
              top_label[idx++] = bbox.difficult();
            }
          }
        }
      }
    }
    else
    {
      LOG(FATAL) << "Unknown annotation type.";
    }
  }
  timer.Stop();
  batch_timer.Stop();
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}

INSTANTIATE_CLASS(AnnotatedDataLayer);
REGISTER_LAYER_CLASS(AnnotatedData);

}  // namespace caffe

 
 

The data layer relies on a few important helper functions: GenerateBatchSamples(), this->data_transformer_->CropImage() and this->data_transformer_->Transform(). Let's walk through each of them in detail.

GenerateBatchSamples

void GenerateBatchSamples(const AnnotatedDatum& anno_datum,
                          const vector<BatchSampler>& batch_samplers,
                          vector<NormalizedBBox>* sampled_bboxes) 
{
  sampled_bboxes->clear();

  // Collect the ground-truth boxes
  vector<NormalizedBBox> object_bboxes;
  GroupObjectBBoxes(anno_datum, &object_bboxes); 

  // Generate multiple boxes for each sampler
  for (int i = 0; i < batch_samplers.size(); ++i) 
  {
     // Use original image as the source for sampling.
    if (batch_samplers[i].use_original_image()) 
    {
      NormalizedBBox unit_bbox;
      unit_bbox.set_xmin(0);
      unit_bbox.set_ymin(0);
      unit_bbox.set_xmax(1);
      unit_bbox.set_ymax(1);
      GenerateSamples(unit_bbox,  // the unit box
                      object_bboxes, // the ground-truth boxes
                      batch_samplers[i], // the sampler
                      sampled_bboxes);
    }
  }
}

void GenerateSamples(const NormalizedBBox& source_bbox, // the unit box
                     const vector<NormalizedBBox>& object_bboxes, // all ground-truth boxes in the image
                     const BatchSampler& batch_sampler, // the sampler
                     vector<NormalizedBBox>* sampled_bboxes) 
{
  int found = 0;

  // Each batch_sampler makes up to max_trials attempts
  for (int i = 0; i < batch_sampler.max_trials(); ++i) 
  {
      // Stop once this batch_sampler has generated at least max_sample bounding boxes
    if (batch_sampler.has_max_sample() && found >= batch_sampler.max_sample()) 
    {
      break;
    }

    // Generate sampled_bbox in the normalized space [0, 1].
    // Randomly generate one box
    NormalizedBBox sampled_bbox;
    SampleBBox(batch_sampler.sampler(), &sampled_bbox);

    // Transform the sampled_bbox w.r.t. source_bbox.
    // Transform into the coordinates of source_bbox; since both are the unit box, the box is unchanged
    LocateBBox(source_bbox, sampled_bbox, &sampled_bbox);

    // Determine if the sampled bbox is positive or negative by the constraint.
    // Check whether the IoU between the generated bounding box and the ground-truth boxes satisfies the constraint
    if (SatisfySampleConstraint(sampled_bbox, object_bboxes, batch_sampler.sample_constraint())) 
    {
      ++found;
      sampled_bboxes->push_back(sampled_bbox);
    }
  }
}
 
 

DataTransformer::CropImage

template<typename Dtype>
void DataTransformer<Dtype>::CropImage(const AnnotatedDatum& anno_datum,
                                       const NormalizedBBox& bbox,
                                       AnnotatedDatum* cropped_anno_datum)
{
  // First crop the data: map bbox to pixel coordinates in the original image, crop out that region and store the cropped image in cropped_anno_datum
  CropImage(anno_datum.datum(), bbox, cropped_anno_datum->mutable_datum());
  cropped_anno_datum->set_type(anno_datum.type());

  // Transform the annotations according to crop_bbox:
  // cropped_anno_datum holds, for each class, the offset of every bbox relative to crop_bbox, computed as (ground truth - crop_bbox) / crop_bbox
  // Transform the annotation according to crop_bbox.
  const bool do_resize = false;
  const bool do_mirror = false;
  NormalizedBBox crop_bbox;
  ClipBBox(bbox, &crop_bbox); // clip to the image boundary
  TransformAnnotation(anno_datum, do_resize, crop_bbox, do_mirror,cropped_anno_datum->mutable_annotation_group());
}

template<typename Dtype>
void DataTransformer<Dtype>::TransformAnnotation(
    const AnnotatedDatum& anno_datum, const bool do_resize,
    const NormalizedBBox& crop_bbox, const bool do_mirror,
    RepeatedPtrField<AnnotationGroup>* transformed_anno_group_all) 
{
  const int img_height = anno_datum.datum().height();
  const int img_width = anno_datum.datum().width();
  if (anno_datum.type() == AnnotatedDatum_AnnotationType_BBOX) 
  {
    // Go through each AnnotationGroup.
    // For each class, compute the offsets of all bboxes relative to crop_bbox, i.e. the coordinates of every object inside the randomly generated box
    for (int g = 0; g < anno_datum.annotation_group_size(); ++g) 
    {
      const AnnotationGroup& anno_group = anno_datum.annotation_group(g);
      AnnotationGroup transformed_anno_group;
      bool has_valid_annotation = false;

      // All Annotations of this class
      for (int a = 0; a < anno_group.annotation_size(); ++a) 
      {
        const Annotation& anno = anno_group.annotation(a);
        const NormalizedBBox& bbox = anno.bbox();

        // Adjust bounding box annotation.
        NormalizedBBox resize_bbox = bbox;

        // Here both do_resize and do_mirror are false
        if (do_resize && param_.has_resize_param()) 
        {
          CHECK_GT(img_height, 0);
          CHECK_GT(img_width, 0);
          UpdateBBoxByResizePolicy(param_.resize_param(), img_width, img_height,&resize_bbox);
        }
        if (param_.has_emit_constraint() &&!MeetEmitConstraint(crop_bbox, resize_bbox,param_.emit_constraint())) 
        {
          continue;
        }
        // ProjectBBox computes the offset between the ground truth and the randomly generated bbox (only for boxes that intersect it)
        NormalizedBBox proj_bbox; // proj_bbox is the offset, used as the regression target
        if (ProjectBBox(crop_bbox, resize_bbox, &proj_bbox))
        {
          has_valid_annotation = true;
          Annotation* transformed_anno =transformed_anno_group.add_annotation();
          transformed_anno->set_instance_id(anno.instance_id());
          NormalizedBBox* transformed_bbox = transformed_anno->mutable_bbox();
          transformed_bbox->CopyFrom(proj_bbox);

          if (do_mirror) 
          {
            Dtype temp = transformed_bbox->xmin();
            transformed_bbox->set_xmin(1 - transformed_bbox->xmax());
            transformed_bbox->set_xmax(1 - temp);
          }
          if (do_resize && param_.has_resize_param()) 
          {
            ExtrapolateBBox(param_.resize_param(), img_height, img_width,crop_bbox, transformed_bbox);
          }
        }
      }
      // Save for output.
      if (has_valid_annotation)
      {
        // After visiting all ground truths of this class, set the label
        transformed_anno_group.set_group_label(anno_group.group_label());
        transformed_anno_group_all->Add()->CopyFrom(transformed_anno_group);
      }
    }
  }
  else
  {
    LOG(FATAL) << "Unknown annotation type.";
  }
}
 
 

DataTransformer::Transform

template<typename Dtype>
void DataTransformer<Dtype>::Transform(
    const AnnotatedDatum& anno_datum, Blob<Dtype>* transformed_blob,
    vector<AnnotationGroup>* transformed_anno_vec) 
{
  bool do_mirror;
  Transform(anno_datum, transformed_blob, transformed_anno_vec, &do_mirror);
}

template<typename Dtype>
void DataTransformer<Dtype>::Transform(
    const AnnotatedDatum& anno_datum, Blob<Dtype>* transformed_blob,
    vector<AnnotationGroup>* transformed_anno_vec, bool* do_mirror) {
  RepeatedPtrField<AnnotationGroup> transformed_anno_group_all;
  Transform(anno_datum, transformed_blob, &transformed_anno_group_all,
            do_mirror);
  for (int g = 0; g < transformed_anno_group_all.size(); ++g) {
    transformed_anno_vec->push_back(transformed_anno_group_all.Get(g));
  }
}

template<typename Dtype>
void DataTransformer<Dtype>::Transform(
    const AnnotatedDatum& anno_datum, Blob<Dtype>* transformed_blob,
    RepeatedPtrField<AnnotationGroup>* transformed_anno_group_all,
    bool* do_mirror) 
{

  // Transform datum.
  /* Transform the data.
   * If the DataTransformer parameters contain no crop_size, crop_box keeps the
   * size of the whole image (in normalized coordinates, (0, 0, 1, 1) here).
   * If the data layer has a resize parameter, the image is resized.
   */
  const Datum& datum = anno_datum.datum();
  NormalizedBBox crop_bbox;
  Transform(datum, transformed_blob, &crop_bbox, do_mirror);


  // Transform annotation.
  /* Transform the annotations:
   * compute the coordinates of every object in the final cropped image
   * (if transform_param has a crop_size, a region is cropped out).
  */
  const bool do_resize = true;
  TransformAnnotation(anno_datum, do_resize, crop_bbox, *do_mirror,
                      transformed_anno_group_all);
}

 
 

That is roughly all there is to the data layer source code. If you have any questions, feel free to leave a comment and we can discuss them.

2018-4-29 22:44:01
Last updated: 2018-5-1 10:53:20

