caffe修改源码实现多label

在我的上一篇博客中caffe实现多标签输入中,介绍了用把图像和label分来,各自做成lmdb,最后把label的lmdb用slice层分开,这篇博客介绍另一种修改源码的方法实现多label,比其他博客改动源码最少

简介

我们都知道ImageDataLayer是直接读取原图进行分类,它的label是单label,文件格式如下

train.txt示例
001.jpg 1
002.jpg 2
003.jpg 3

layer {  
  name: "demo" type: "ImageData" top: "data" top: "label" include { phase: TRAIN }  
  transform_param {  
    scale: 0.00390625 mean_value: 128 }  
  image_data_param {  
    source: "your path/train.txt" root_folder: "your image data path" new_height: xxx new_width: xxx batch_size: 32 shuffle: true #每个epoch都会进行shuffle }  
}  
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

修改代码

由于ImageataLayer的限制,我们只能在train.txt中放置单label,现在我们来修改ImageataLayer的代码来实现多label
主要涉及三个文件

  • caffe/src/caffe/proto/caffe.proto
  • caffe/include/caffe/layers/image_data_layer.hpp
  • caffe/src/caffe/layers/image_data_layer.cpp

定位到caffe/src/caffe/proto/caffe.proto中message ImageDataParameter

// 添加一个参数
// Specify the label dim. default 1.
// 有几种label,比如性别、年龄两种label,在网络结构里就把此参数设置为2
optional uint32 label_dim = IDNumber [default = 1]; 
// IDNumber是和其它参数不冲突的ID数字
   
   
  • 1
  • 2
  • 3
  • 4
  • 5

定位到caffe/include/caffe/layers/image_data_layer.hpp

// 修改vector<std::pair<std::string, int> > lines_;
// string对应那个train.txt中的图片名称,in对应label,我们把int改为int*,实现多label
vector<std::pair<std::string, int *> > lines_;
   
   
  • 1
  • 2
  • 3

定位到caffe/src/caffe/layers/image_data_layer.cpp

// DataLayerSetUp函数
// 原本的加载图片名称和label的代码
  std::ifstream infile(source.c_str());
  string line;
  size_t pos;
  int label;
  while (std::getline(infile, line)) {
    pos = line.find_last_of(' ');
    label = atoi(line.substr(pos + 1).c_str());
    lines_.push_back(std::make_pair(line.substr(0, pos), label));
  }
// 修改为这样
std::ifstream infile(source.c_str());
  string filename;
  // 获取label的种类
  int label_dim = this->layer_param_.image_data_param().label_dim();
  // 注意这里默认每个label直接以空格隔开,每个图片名称及其label占一行,如果你的格式不同,可自行修改读取方式
  while (infile >> filename) {
    int* labels = new int[label_dim];
    for(int i = 0;i < label_dim;++i){
        infile >> labels[i];
    }
    lines_.push_back(std::make_pair(filename, labels));
  }
// 原本的输出label
  vector<int> label_shape(1, batch_size);
  top[1]->Reshape(label_shape);
  for (int i = 0; i < this->prefetch_.size(); ++i) {
    this->prefetch_[i]->label_.Reshape(label_shape);
  }
// 修改为这样
  vector<int> label_shape(2);
  label_shape[0] = batch_size;
  label_shape[1] = label_dim;
  top[1]->Reshape(label_shape); // label的输出shape batch_size*label_dim
  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
    this->prefetch_[i].label_.Reshape(label_shape);
  }
// 注意:caffe最新版本prefetch_的结构由之前的Batch<Dtype> prefetch_[PREFETCH_COUNT];
// 改为 vector<shared_ptr<Batch<Dtype> > > prefetch_; 由对象数组改为了存放shared指针的vector。
// 所以此处的this->PREFETCH_COUNT改为this->prefetch_.size(); 
// 此处的this->prefetch_[i].label_.Reshape(label_shape);
// 改为this->prefetch_[i]->label_.Reshape(label_shape);把.改成指针的->
// load_batch函数
// 在函数一开始先获取下label_dim参数
int label_dim = this->layer_param_.image_data_param().label_dim();
// 原本的预取label
prefetch_label[item_id] = lines_[lines_id_].second;
// 修改为这样
for(int i = 0;i < label_dim;++i){
    // lines_[lines_id_].second就是最开始改为的int*,多label
    prefetch_label[item_id * label_dim + i] = lines_[lines_id_].second[i];
}
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53

最后进行make

示例

train.txt
001.jpg 1 3 2
002.jpg 2 4 7
003.jpg 3 0 9
   
   
  • 1
  • 2
  • 3
  • 4
# trainval.prototxt
layer {
  name: ”data” type: “ImageData” top: “data” top: “label” include { phase: TRAIN }
  transform_param {
    mirror: true mean_value: 128 mean_value: 128 mean_value: 128 }
  image_data_param {
    mirror: true source: ”your path/train.txt” root_folder: “your image data path” new_height: xxx new_width: xxx batch_size: 32 shuffle: true #每个epoch都会进行shuffle label_dim: 3 }
}
layer {
  name: ”slice” type: “Slice” bottom: “label” top: “label_1” top: “label_2” top: “label_3” slice_param { axis: 1 slice_point:1 slice_point:2 }
}
……




后续代码参考我另一篇caffe实现多label输入的代码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41

这里写图片描述

总结

对比其他修改源码的博客,本篇对源码的改动最少,而且兼容原来的版本,最重要的是使用起来最方便,不用生成lmdb之类




MathJax.Hub.Config({
showMathMenu: false,
preferredFont: "STIX",
"HTML-CSS": {availableFonts: ["STIX","TeX"]}
});



@font-face {
font-family: ‘MathJax_Main’;
src: url(‘ http://static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/fonts/HTML-CSS/TeX/eot/MathJax_Math-Italic.eot‘); /* IE9 Compat Modes */
src: url(‘ http://static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/fonts/HTML-CSS/TeX/eot/MathJax_Math-Italic.eot?iefix‘) format(‘eot’),
url(‘ http://static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/fonts/HTML-CSS/TeX/woff/MathJax_Math-Italic.woff‘) format(‘woff’),
url(‘ http://static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/fonts/HTML-CSS/TeX/otf/MathJax_Math-Italic.otf‘) format(‘opentype’),
url(‘ http://static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/fonts/HTML-CSS/TeX/svg/MathJax_Math-Italic.svg#MathJax_Math-Italic‘) format(‘svg’);
}
@font-face {
font-family: ‘MathJax_Main’;
src: url(‘ http://static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/fonts/HTML-CSS/TeX/eot/MathJax_Main-Regular.eot‘); /* IE9 Compat Modes */
src: url(‘ http://static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/fonts/HTML-CSS/TeX/eot/MathJax_Main-Regular.eot?iefix‘) format(‘eot’),
url(‘ http://static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/fonts/HTML-CSS/TeX/woff/MathJax_Main-Regular.woff‘) format(‘woff’),
url(‘ http://static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/fonts/HTML-CSS/TeX/otf/MathJax_Main-Regular.otf‘) format(‘opentype’),
url(‘ http://static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/fonts/HTML-CSS/TeX/svg/MathJax_Main-Regular.svg#MathJax_Main-Regular‘) format(‘svg’);
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值