前言
年后到现在,利用自己的业余时间断断续续将caffe的SSD源码看完了,虽然中间由于工作原因暂停了一段时间,但最终还算顺利完成了,SSD源码的阅读也是今年的年度计划中比较重要的一项内容,完成了还是很有成就感的。阅读完代码后,一个最大的体会就是之前论文中很多困惑我的细节现在豁然开朗了,哈哈。
在阅读代码期间,每次遇到困惑我的地方,我会反复思考,琢磨,利用走路,吃饭的时间思考,也常常会在宿舍里来回踱步,现在我对阅读代码有了一个新的体会。当你阅读一段对你来说很难的代码的时候,不要害怕,你只要静下心来,将一段很难的代码拆分成N个子块,然后针对每个子块各个击破,等你将所有子块都击破了,然后再将所有子块串联起来连接成一个整体,再从整体思考这段代码,会有更加深刻的理解。当然这个过程起初会很难,因为起初很多东西你都不懂,就像我阅读SSD代码的时候,起初很难,要了解很多细节,但是只要有耐心有毅力,慢慢你会发现,你对这些内容越来越熟悉,你也会感到越来越轻松,直到最后你豁然开朗,发现这段很难的代码也不过如此,那种感觉实在是太美妙了。
五一的第一天,我稍微整理了一下SSD源码的阅读笔记,写成博客,与大家一起分享交流。由于SSD源码比较复杂,加上时间精力有限,不可能对每个细节都有深入的理解,博客中如有不足之处,希望大家能够提出宝贵的意见。
这篇博客是SSD源码解读系列的第1篇,对数据层进行解读。
阅读SSD源码的时候,我为SSD源码创建了Qt工程,这样方便阅读。该Qt工程已上传到CSDN,可以直接用Qt打开,大家可以下载该工程阅读,提高阅读效率。
点击下载
数据层AnnotatedDataLayer源码解读
#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>
#endif // USE_OPENCV
#include <stdint.h>
#include <algorithm>
#include <map>
#include <vector>
#include "caffe/data_transformer.hpp"
#include "caffe/layers/annotated_data_layer.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/sampler.hpp"
namespace caffe {
// Constructs the layer: the prefetching base class and the record reader
// (reader_) are both configured from the same LayerParameter.
template <typename Dtype>
AnnotatedDataLayer<Dtype>::AnnotatedDataLayer(const LayerParameter& param)
    : BasePrefetchingDataLayer<Dtype>(param),
      reader_(param) {
  // No further initialization is needed here; shapes are set up in
  // DataLayerSetUp().
}
// Destructor: stops the layer's internal (prefetch) thread before the
// members it may be using are torn down.
template <typename Dtype>
AnnotatedDataLayer<Dtype>::~AnnotatedDataLayer() {
  this->StopInternalThread();
}
// One-time setup of the data layer.
//
// - Copies every batch_sampler (data-augmentation sampler) and the label map
//   file name out of annotated_data_param.
// - Enforces batch_size == 1 when the resize mode is FIT_SMALL_SIZE, because
//   that mode lets the image shape vary from image to image.
// - Peeks at the next record to infer the image shape, then reshapes
//   transformed_data_ (one image), top[0] and every prefetch data blob
//   (batch_size images).
// - If labels are produced, decides whether they are "rich" annotations
//   (has_anno_type_) and sizes top[1] and every prefetch label blob:
//     * BBOX annotations: [1, 1, max(num_bboxes, 1), 8];
//     * plain labels:     [batch_size, 1, 1, 1].
template <typename Dtype>
void AnnotatedDataLayer<Dtype>::DataLayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  const int batch_size = this->layer_param_.data_param().batch_size();
  const AnnotatedDataParameter& anno_data_param =
      this->layer_param_.annotated_data_param();

  // Collect every configured data-augmentation sampler.
  const int num_samplers = anno_data_param.batch_sampler_size();
  for (int s = 0; s < num_samplers; ++s) {
    batch_samplers_.push_back(anno_data_param.batch_sampler(s));
  }
  label_map_file_ = anno_data_param.label_map_file();

  // Make sure dimension is consistent within batch.
  const TransformationParameter& transform_param =
      this->layer_param_.transform_param();
  const bool fit_small_size =
      transform_param.has_resize_param() &&
      transform_param.resize_param().resize_mode() ==
          ResizeParameter_Resize_mode_FIT_SMALL_SIZE;
  if (fit_small_size) {
    CHECK_EQ(batch_size, 1) << "Only support batch size of 1 for FIT_SMALL_SIZE.";
  }

  // Peek at (without consuming) the next record. An AnnotatedDatum carries
  // both the image and its annotations (labels and bounding boxes).
  AnnotatedDatum& first_datum = *(reader_.full().peek());
  // Use data_transformer to infer the expected blob shape from the record.
  vector<int> data_shape =
      this->data_transformer_->InferBlobShape(first_datum.datum());
  this->transformed_data_.Reshape(data_shape);
  // Reshape top[0] and every prefetch data blob to hold batch_size images.
  data_shape[0] = batch_size;
  top[0]->Reshape(data_shape);
  for (int p = 0; p < this->PREFETCH_COUNT; ++p) {
    this->prefetch_[p].data_.Reshape(data_shape);
  }
  LOG(INFO) << "output data size: " << top[0]->num() << ","
            << top[0]->channels() << "," << top[0]->height() << ","
            << top[0]->width();

  if (!this->output_labels_) {
    return;
  }

  // The annotation type comes from the record, or from the layer parameter.
  has_anno_type_ = first_datum.has_type() || anno_data_param.has_anno_type();
  vector<int> label_shape(4, 1);
  if (!has_anno_type_) {
    // Plain labels: one value per image.
    label_shape[0] = batch_size;
  } else {
    anno_type_ = first_datum.type();
    if (anno_data_param.has_anno_type()) {
      // If anno_type is provided in AnnotatedDataParameter, replace
      // the type stored in each individual AnnotatedDatum.
      LOG(WARNING) << "type stored in AnnotatedDatum is shadowed.";
      anno_type_ = anno_data_param.anno_type();
    }
    if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
      // Since the number of boxes differs between images, all boxes are
      // stored in one spatial plane (num and channels stay 1), one box per
      // row: [item_id, group_label, instance_id, xmin, ymin, xmax, ymax, diff]
      // (see caffe.proto for group_label and instance_id).
      int total_boxes = 0;
      const int num_groups = first_datum.annotation_group_size();
      for (int g = 0; g < num_groups; ++g) {
        total_boxes += first_datum.annotation_group(g).annotation_size();
      }
      // label_shape[0] and label_shape[1] keep their initial value of 1.
      // BasePrefetchingDataLayer<Dtype>::LayerSetUp() requires to call
      // cpu_data and gpu_data for consistent prefetch thread, so keep at
      // least one row.
      label_shape[2] = std::max(total_boxes, 1);
      label_shape[3] = 8;
    } else {
      LOG(FATAL) << "Unknown annotation type.";
    }
  }

  top[1]->Reshape(label_shape);
  for (int p = 0; p < this->PREFETCH_COUNT; ++p) {
    this->prefetch_[p].label_.Reshape(label_shape);
  }
}
// This function is called on prefetch thread.
//
// Fills one Batch with batch_size records taken from reader_:
//   * each record may be distorted (DistortImage, when distort_param is set)
//     and/or expanded (ExpandImage, when expand_param is set);
//   * if batch samplers are configured, one of the boxes produced by
//     GenerateBatchSamples() is picked at random and the image is cropped to
//     it (annotations are re-expressed relative to the crop);
//   * the result is transformed (resize, mirror, ...) into batch->data_.
// With annotations (has_anno_type_), batch->label_ is reshaped to
// [1, 1, num_bboxes, 8], one row per box:
//   [item_id, group_label, instance_id, xmin, ymin, xmax, ymax, difficult]
// A single row of -1 is written when the batch contains no box at all.
// Without annotations, one label value per image is written instead.
template<typename Dtype>
void AnnotatedDataLayer<Dtype>::load_batch(Batch<Dtype>* batch)
{
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());

  // Reshape according to the first anno_datum of each batch
  // on single input batches allows for inputs of varying dimension.
  const int batch_size = this->layer_param_.data_param().batch_size();
  const AnnotatedDataParameter& anno_data_param =
      this->layer_param_.annotated_data_param();
  const TransformationParameter& transform_param =
      this->layer_param_.transform_param();
  // Peek at (do not consume) the next record to size the buffers.
  AnnotatedDatum& first_datum = *(reader_.full().peek());
  vector<int> top_shape =
      this->data_transformer_->InferBlobShape(first_datum.datum());
  // transformed_data_ holds one image (e.g. [1, 3, 300, 300] for SSD300).
  this->transformed_data_.Reshape(top_shape);
  // batch->data_ holds the whole batch (e.g. [batch_size, 3, 300, 300]).
  top_shape[0] = batch_size;
  batch->data_.Reshape(top_shape);

  Dtype* top_data = batch->data_.mutable_cpu_data();
  Dtype* top_label = NULL;  // suppress warnings about uninitialized variables
  if (this->output_labels_ && !has_anno_type_) {
    top_label = batch->label_.mutable_cpu_data();
  }

  // Store transformed annotation: item_id -> annotation groups of that image.
  map<int, vector<AnnotationGroup> > all_anno;
  int num_bboxes = 0;
  for (int item_id = 0; item_id < batch_size; ++item_id) {
    timer.Start();
    // Take the next record from the reader's full queue.
    AnnotatedDatum& anno_datum = *(reader_.full().pop("Waiting for data"));
    read_time += timer.MicroSeconds();
    // Transform time is measured from here to the end of the iteration, so
    // it includes distortion, expansion, sampling/cropping and Transform().
    // (Previously the timer was restarted again before the final Transform,
    // which silently dropped all augmentation time from the statistic.)
    timer.Start();

    // Optional distortion followed by optional expansion. expand_datum is
    // heap-allocated (and deleted below) exactly when expand_param is set;
    // otherwise it points at distort_datum or at the reader's record.
    AnnotatedDatum distort_datum;
    AnnotatedDatum* expand_datum = NULL;
    if (transform_param.has_distort_param()) {
      distort_datum.CopyFrom(anno_datum);
      this->data_transformer_->DistortImage(anno_datum.datum(),
                                            distort_datum.mutable_datum());
      if (transform_param.has_expand_param()) {
        expand_datum = new AnnotatedDatum();
        this->data_transformer_->ExpandImage(distort_datum, expand_datum);
      } else {
        expand_datum = &distort_datum;
      }
    } else {
      if (transform_param.has_expand_param()) {
        expand_datum = new AnnotatedDatum();
        this->data_transformer_->ExpandImage(anno_datum, expand_datum);
      } else {
        expand_datum = &anno_datum;
      }
    }

    // sampled_datum is heap-allocated (and deleted below) exactly when
    // has_sampled is true; otherwise it aliases expand_datum.
    AnnotatedDatum* sampled_datum = NULL;
    bool has_sampled = false;
    if (batch_samplers_.size() > 0) {
      // 1. Data augmentation by random sampling (SSD paper, section 2.2,
      //    "Data augmentation"). For each batch_sampler, GenerateSamples()
      //    makes up to max_trials random attempts and keeps at most
      //    max_sample boxes that satisfy the sampler's sample_constraint
      //    against the ground-truth boxes. Example configuration:
      //
      //      batch_sampler {
      //        sampler {
      //          min_scale: 0.3  max_scale: 1.0
      //          min_aspect_ratio: 0.5  max_aspect_ratio: 2.0
      //        }
      //        sample_constraint { min_jaccard_overlap: 0.7 }
      //        max_sample: 1
      //        max_trials: 50
      //      }
      //
      //    Presumably this keeps boxes whose IoU with at least one
      //    ground-truth box reaches 0.7 (see SatisfySampleConstraint).
      //    Sampled boxes are in normalized [0, 1] coordinates, so they do
      //    not depend on the image resolution.
      vector<NormalizedBBox> sampled_bboxes;
      GenerateBatchSamples(*expand_datum, batch_samplers_, &sampled_bboxes);
      // 2. Pick one sampled box at random and crop the image to it.
      //    CropImage() also re-expresses the ground-truth boxes relative to
      //    the crop (presumably (gt - crop_origin) / crop_size in normalized
      //    coordinates). If no box was accepted, the uncropped image is used.
      if (sampled_bboxes.size() > 0) {
        int rand_idx = caffe_rng_rand() % sampled_bboxes.size();
        sampled_datum = new AnnotatedDatum();
        this->data_transformer_->CropImage(*expand_datum,
                                           sampled_bboxes[rand_idx],
                                           sampled_datum);
        has_sampled = true;
      } else {
        sampled_datum = expand_datum;
      }
    } else {
      sampled_datum = expand_datum;
    }
    CHECK(sampled_datum != NULL);

    vector<int> shape =
        this->data_transformer_->InferBlobShape(sampled_datum->datum());
    if (transform_param.has_resize_param()) {
      if (transform_param.resize_param().resize_mode() ==
          ResizeParameter_Resize_mode_FIT_SMALL_SIZE) {
        // Only FIT_SMALL_SIZE lets the output shape vary per image; it is
        // restricted to batch_size == 1 in DataLayerSetUp, so the batch blob
        // simply takes this image's shape.
        this->transformed_data_.Reshape(shape);
        batch->data_.Reshape(shape);
        top_data = batch->data_.mutable_cpu_data();
      } else {
        CHECK(std::equal(top_shape.begin() + 1, top_shape.begin() + 4,
                         shape.begin() + 1));
      }
    } else {
      CHECK(std::equal(top_shape.begin() + 1, top_shape.begin() + 4,
                       shape.begin() + 1));
    }

    // Apply data transformations (mirror, scale, crop...).
    // transformed_data_ is pointed at this image's slot inside batch->data_.
    int offset = batch->data_.offset(item_id);
    this->transformed_data_.set_cpu_data(top_data + offset);
    vector<AnnotationGroup> transformed_anno_vec;
    if (this->output_labels_) {
      if (has_anno_type_) {
        // Make sure all data have same annotation type.
        CHECK(sampled_datum->has_type()) << "Some datum misses AnnotationType.";
        if (anno_data_param.has_anno_type()) {
          sampled_datum->set_type(anno_type_);
        } else {
          CHECK_EQ(anno_type_, sampled_datum->type()) <<
              "Different AnnotationType.";
        }
        // Transform datum and annotation_group at the same time.
        transformed_anno_vec.clear();
        // 3. Transform the (possibly cropped) image into its slot of
        //    batch->data_ and compute the final box coordinates.
        //    Notes from the original author:
        //    - transform_param may crop again (e.g. when crop_size is set),
        //      so the stored image is not necessarily the sampled crop;
        //    - the crop is resized here, so resizing images when the LMDB is
        //      built means they get resized twice, which may distort object
        //      aspect ratios. SFD (Single Shot Scale-invariant Face Detector)
        //      works around this by sampling only square boxes in step 1.
        this->data_transformer_->Transform(*sampled_datum,
                                           &(this->transformed_data_),
                                           &transformed_anno_vec);
        if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
          // Count the boxes that survived sampling and transformation.
          for (size_t g = 0; g < transformed_anno_vec.size(); ++g) {
            num_bboxes += transformed_anno_vec[g].annotation_size();
          }
        } else {
          LOG(FATAL) << "Unknown annotation type.";
        }
        all_anno[item_id] = transformed_anno_vec;
      } else {
        this->data_transformer_->Transform(sampled_datum->datum(),
                                           &(this->transformed_data_));
        // Otherwise, store the label from datum.
        CHECK(sampled_datum->datum().has_label()) << "Cannot find any label.";
        top_label[item_id] = sampled_datum->datum().label();
      }
    } else {
      this->data_transformer_->Transform(sampled_datum->datum(),
                                         &(this->transformed_data_));
    }

    // Clear memory owned by this iteration.
    if (has_sampled) {
      delete sampled_datum;
    }
    if (transform_param.has_expand_param()) {
      delete expand_datum;
    }
    trans_time += timer.MicroSeconds();

    // Return the record to the reader's free queue.
    reader_.free().push(&anno_datum);
  }

  // Store "rich" annotation if needed.
  // 4. Pack all boxes of the batch into batch->label_ with shape
  //    [1, 1, num_bboxes, 8]; each row is
  //    [item_id, group_label, instance_id, xmin, ymin, xmax, ymax, difficult]:
  //    box instance_id of group group_label (typically the class) in image
  //    item_id of the batch.
  if (this->output_labels_ && has_anno_type_) {
    vector<int> label_shape(4);
    if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
      label_shape[0] = 1;
      label_shape[1] = 1;
      label_shape[3] = 8;
      if (num_bboxes == 0) {
        // Store all -1 in the label.
        label_shape[2] = 1;
        batch->label_.Reshape(label_shape);
        caffe_set<Dtype>(8, -1, batch->label_.mutable_cpu_data());
      } else {
        // num_bboxes is the total number of boxes over the whole batch.
        label_shape[2] = num_bboxes;
        batch->label_.Reshape(label_shape);
        top_label = batch->label_.mutable_cpu_data();
        int idx = 0;
        for (int item_id = 0; item_id < batch_size; ++item_id) {
          // Annotation groups of image item_id.
          const vector<AnnotationGroup>& anno_vec = all_anno[item_id];
          for (size_t g = 0; g < anno_vec.size(); ++g) {
            const AnnotationGroup& anno_group = anno_vec[g];
            for (int a = 0; a < anno_group.annotation_size(); ++a) {
              const Annotation& anno = anno_group.annotation(a);
              const NormalizedBBox& bbox = anno.bbox();
              top_label[idx++] = item_id;
              top_label[idx++] = anno_group.group_label();
              top_label[idx++] = anno.instance_id();
              top_label[idx++] = bbox.xmin();
              top_label[idx++] = bbox.ymin();
              top_label[idx++] = bbox.xmax();
              top_label[idx++] = bbox.ymax();
              top_label[idx++] = bbox.difficult();
            }
          }
        }
      }
    } else {
      LOG(FATAL) << "Unknown annotation type.";
    }
  }
  timer.Stop();
  batch_timer.Stop();
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << " Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}
// Explicit template instantiation and layer-factory registration for this
// layer. NOTE(review): assumes Caffe's usual macro behavior (instantiation
// for float and double; registration under the type name "AnnotatedData") —
// confirm in caffe/common.hpp and caffe/layer_factory.hpp.
INSTANTIATE_CLASS(AnnotatedDataLayer);
REGISTER_LAYER_CLASS(AnnotatedData);
} // namespace caffe
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
- 155
- 156
- 157
- 158
- 159
- 160
- 161
- 162
- 163
- 164
- 165
- 166
- 167
- 168
- 169
- 170
- 171
- 172
- 173
- 174
- 175
- 176
- 177
- 178
- 179
- 180
- 181
- 182
- 183
- 184
- 185
- 186
- 187
- 188
- 189
- 190
- 191
- 192
- 193
- 194
- 195
- 196
- 197
- 198
- 199
- 200
- 201
- 202
- 203
- 204
- 205
- 206
- 207
- 208
- 209
- 210
- 211
- 212
- 213
- 214
- 215
- 216
- 217
- 218
- 219
- 220
- 221
- 222
- 223
- 224
- 225
- 226
- 227
- 228
- 229
- 230
- 231
- 232
- 233
- 234
- 235
- 236
- 237
- 238
- 239
- 240
- 241
- 242
- 243
- 244
- 245
- 246
- 247
- 248
- 249
- 250
- 251
- 252
- 253
- 254
- 255
- 256
- 257
- 258
- 259
- 260
- 261
- 262
- 263
- 264
- 265
- 266
- 267
- 268
- 269
- 270
- 271
- 272
- 273
- 274
- 275
- 276
- 277
- 278
- 279
- 280
- 281
- 282
- 283
- 284
- 285
- 286
- 287
- 288
- 289
- 290
- 291
- 292
- 293
- 294
- 295
- 296
- 297
- 298
- 299
- 300
- 301
- 302
- 303
- 304
- 305
- 306
- 307
- 308
- 309
- 310
- 311
- 312
- 313
- 314
- 315
- 316
- 317
- 318
- 319
- 320
- 321
- 322
- 323
- 324
- 325
- 326
- 327
- 328
- 329
- 330
- 331
- 332
- 333
- 334
- 335
- 336
- 337
- 338
- 339
- 340
- 341
- 342
- 343
- 344
- 345
- 346
- 347
- 348
- 349
- 350
- 351
- 352
- 353
- 354
- 355
- 356
- 357
- 358
- 359
- 360
- 361
- 362
- 363
- 364
- 365
- 366
- 367
- 368
- 369
- 370
- 371
- 372
- 373
- 374
- 375
- 376
- 377
- 378
- 379
- 380
- 381
- 382
- 383
- 384
- 385
- 386
- 387
- 388
- 389
- 390
- 391
- 392
- 393
- 394
- 395
- 396
- 397
- 398
- 399
- 400
- 401
- 402
- 403
- 404
- 405
- 406
- 407
- 408
- 409
- 410
- 411
- 412
- 413
- 414
- 415
- 416
- 417
- 418
- 419
- 420
- 421
- 422
- 423
- 424
- 425
- 426
- 427
- 428
- 429
- 430
- 431
- 432
- 433
- 434
- 435
数据层中有几个比较重要的函数:GenerateBatchSamples()、this->data_transformer_->CropImage()和this->data_transformer_->Transform(),下面对它们详细解读一下。
GenerateBatchSamples
void GenerateBatchSamples(const AnnotatedDatum& anno_datum,
const vector<BatchSampler>& batch_samplers,
vector<NormalizedBBox>* sampled_bboxes)
{
sampled_bboxes->clear();
// 获取groundtruth box
vector<NormalizedBBox> object_bboxes;
GroupObjectBBoxes(anno_datum, &object_bboxes);
// 对于每个采样器生成多个box
for (int i = 0; i < batch_samplers.size(); ++i)
{
// Use original image as the source for sampling.
if (batch_samplers[i].use_original_image())
{
NormalizedBBox unit_bbox;
unit_bbox.set_xmin(0);
unit_bbox.set_ymin(0);
unit_bbox.set_xmax(1);
unit_bbox.set_ymax(1);
GenerateSamples(unit_bbox, // 单位box
object_bboxes,// ground truth box
batch_samplers[i], // 采样器
sampled_bboxes);
}
}
}
void GenerateSamples(const NormalizedBBox& source_bbox, // 单位box
const vector<NormalizedBBox>& object_bboxes, // object_bboxes就是该图像中所有的ground truth boxes
const BatchSampler& batch_sampler, // 采样器
vector<NormalizedBBox>* sampled_bboxes)
{
int found = 0;
// 每个采样器batch_sampler都要尝试max_trials次
for (int i = 0; i < batch_sampler.max_trials(); ++i)
{
// 每个batch_sampler生成的boundingbox个数大于等于max_sample了,就跳出
if (batch_sampler.has_max_sample() && found >= batch_sampler.max_sample())
{
break;
}
// Generate sampled_bbox in the normalized space [0, 1].
// 随机生成一个box
NormalizedBBox sampled_bbox;
SampleBBox(batch_sampler.sampler(), &sampled_bbox);
// Transform the sampled_bbox w.r.t. source_bbox.
// 转换为在单位box中的坐标,由于都是单位box,所以转换后还是自己
LocateBBox(source_bbox, sampled_bbox, &sampled_bbox);
// Determine if the sampled bbox is positive or negative by the constraint.
// 所有的ground truth 与生成的boundingbox计算IOU,是否满足条件
if (SatisfySampleConstraint(sampled_bbox, object_bboxes,batch_sampler.sample_constraint()))
{
++found;
sampled_bboxes->push_back(sampled_bbox);
}
}
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
DataTransformer::CropImage
// Crops an annotated image to bbox (normalized coordinates).
//
// 1. The pixels are cropped by the Datum overload of CropImage(), which
//    writes the result into cropped_anno_datum's datum; the annotation type
//    is copied over.
// 2. The annotations are re-expressed relative to the crop window (bbox
//    after ClipBBox, presumably clipped to the image) via
//    TransformAnnotation(), without resizing or mirroring; the resulting
//    groups are appended to cropped_anno_datum's annotation_group.
template<typename Dtype>
void DataTransformer<Dtype>::CropImage(const AnnotatedDatum& anno_datum,
                                       const NormalizedBBox& bbox,
                                       AnnotatedDatum* cropped_anno_datum)
{
  // Step 1: crop the image data.
  CropImage(anno_datum.datum(), bbox, cropped_anno_datum->mutable_datum());
  cropped_anno_datum->set_type(anno_datum.type());

  // Step 2: transform the annotation according to the clipped crop window.
  NormalizedBBox clipped_window;
  ClipBBox(bbox, &clipped_window);
  TransformAnnotation(anno_datum, /*do_resize=*/false, clipped_window,
                      /*do_mirror=*/false,
                      cropped_anno_datum->mutable_annotation_group());
}
// Re-expresses the BBOX annotations of anno_datum relative to crop_bbox and
// appends the surviving groups to *transformed_anno_group_all.
//
// For every annotation of every group:
//   - if do_resize and a resize_param is configured, the box is first
//     adjusted with UpdateBBoxByResizePolicy();
//   - if an emit_constraint is configured and MeetEmitConstraint() rejects
//     the box for crop_bbox, the box is dropped;
//   - ProjectBBox() projects the box into crop_bbox's frame; boxes it
//     rejects are dropped;
//   - if do_mirror, the projected box is flipped horizontally;
//   - if do_resize and a resize_param is configured, ExtrapolateBBox() is
//     applied to the projected box.
// A group is emitted (with its group_label) only if at least one of its
// boxes survived. Any annotation type other than BBOX is fatal.
//
// CropImage() calls this with do_resize = do_mirror = false; the
// AnnotatedDatum overload of Transform() calls it with do_resize = true.
template<typename Dtype>
void DataTransformer<Dtype>::TransformAnnotation(
    const AnnotatedDatum& anno_datum, const bool do_resize,
    const NormalizedBBox& crop_bbox, const bool do_mirror,
    RepeatedPtrField<AnnotationGroup>* transformed_anno_group_all)
{
  if (anno_datum.type() != AnnotatedDatum_AnnotationType_BBOX) {
    LOG(FATAL) << "Unknown annotation type.";
    return;
  }

  const int img_height = anno_datum.datum().height();
  const int img_width = anno_datum.datum().width();
  const bool resize_boxes = do_resize && param_.has_resize_param();

  // Go through each AnnotationGroup.
  const int num_groups = anno_datum.annotation_group_size();
  for (int g = 0; g < num_groups; ++g) {
    const AnnotationGroup& src_group = anno_datum.annotation_group(g);
    AnnotationGroup kept_group;

    for (int a = 0; a < src_group.annotation_size(); ++a) {
      const Annotation& src_anno = src_group.annotation(a);
      // Adjust bounding box annotation.
      NormalizedBBox box = src_anno.bbox();
      if (resize_boxes) {
        CHECK_GT(img_height, 0);
        CHECK_GT(img_width, 0);
        UpdateBBoxByResizePolicy(param_.resize_param(), img_width, img_height,
                                 &box);
      }
      if (param_.has_emit_constraint() &&
          !MeetEmitConstraint(crop_bbox, box, param_.emit_constraint())) {
        continue;
      }

      // Project the box into the crop window; boxes rejected by
      // ProjectBBox() are dropped.
      NormalizedBBox projected;
      if (!ProjectBBox(crop_bbox, box, &projected)) {
        continue;
      }

      Annotation* out_anno = kept_group.add_annotation();
      out_anno->set_instance_id(src_anno.instance_id());
      NormalizedBBox* out_box = out_anno->mutable_bbox();
      out_box->CopyFrom(projected);
      if (do_mirror) {
        // Horizontal flip in normalized coordinates.
        const Dtype old_xmin = out_box->xmin();
        out_box->set_xmin(1 - out_box->xmax());
        out_box->set_xmax(1 - old_xmin);
      }
      if (resize_boxes) {
        ExtrapolateBBox(param_.resize_param(), img_height, img_width,
                        crop_bbox, out_box);
      }
    }

    // Save for output: only groups that kept at least one box.
    if (kept_group.annotation_size() > 0) {
      kept_group.set_group_label(src_group.group_label());
      transformed_anno_group_all->Add()->CopyFrom(kept_group);
    }
  }
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
DataTransformer::Transform
// Convenience overload: transforms the image of anno_datum into
// transformed_blob and appends the adjusted annotation groups to
// *transformed_anno_vec, discarding whether the sample was mirrored.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(
    const AnnotatedDatum& anno_datum, Blob<Dtype>* transformed_blob,
    vector<AnnotationGroup>* transformed_anno_vec)
{
  // Initialized so the flag never holds an indeterminate value (reading an
  // uninitialized bool is undefined behavior), even if some path of the
  // callee does not write it.
  bool do_mirror = false;
  Transform(anno_datum, transformed_blob, transformed_anno_vec, &do_mirror);
}
// Same as the RepeatedPtrField overload, but appends the resulting
// annotation groups to a std::vector. Existing elements of
// *transformed_anno_vec are kept; new groups are appended in order.
// *do_mirror is forwarded unchanged to the RepeatedPtrField overload.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(
    const AnnotatedDatum& anno_datum, Blob<Dtype>* transformed_blob,
    vector<AnnotationGroup>* transformed_anno_vec, bool* do_mirror)
{
  RepeatedPtrField<AnnotationGroup> groups;
  Transform(anno_datum, transformed_blob, &groups, do_mirror);
  transformed_anno_vec->insert(transformed_anno_vec->end(),
                               groups.begin(), groups.end());
}
// Transforms both the image and the annotations of anno_datum.
//
// 1. The Datum overload of Transform() writes the image into
//    transformed_blob; it receives &crop_window and do_mirror as
//    out-parameters (presumably the crop window it used — the whole image
//    when no crop is applied — and whether it mirrored).
// 2. TransformAnnotation() then re-expresses every box relative to that crop
//    window with do_resize = true (the resize policy is applied only if a
//    resize_param is configured) and the same mirror flag, appending the
//    result to *transformed_anno_group_all.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(
    const AnnotatedDatum& anno_datum, Blob<Dtype>* transformed_blob,
    RepeatedPtrField<AnnotationGroup>* transformed_anno_group_all,
    bool* do_mirror)
{
  // Image part.
  NormalizedBBox crop_window;
  Transform(anno_datum.datum(), transformed_blob, &crop_window, do_mirror);

  // Annotation part.
  TransformAnnotation(anno_datum, /*do_resize=*/true, crop_window, *do_mirror,
                      transformed_anno_group_all);
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
数据层的源码大概就是这样,大家有什么疑问的,可以留言一起讨论。
2018-4-29 22:44:01
Last updated: 2018-5-1 10:53:20
非常感谢您的阅读,如果您觉得这篇文章对您有帮助,欢迎扫码进行赞赏。