caffe增加自己的layer实战(下)--caffe学习(12)

接上篇 caffe增加自己的layer实战(中)–caffe学习(11)
先放出完整的修改后的video_data_layers.cpp

#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "caffe/data_layers.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"

#ifdef USE_MPI
#include "mpi.h"
#include <boost/filesystem.hpp>
using namespace boost::filesystem;
#endif

namespace caffe{
/// Destructor: joins the prefetch thread (via JoinPrefetchThread) before the
/// layer's members are destroyed.
template <typename Dtype>
VideoDataLayer<Dtype>::~VideoDataLayer(){
    // The destructor is named with the injected class name; spelling it as
    // ~VideoDataLayer<Dtype>() is rejected by compilers in C++20 mode.
    this->JoinPrefetchThread();
}

/**
 * @brief One-time setup of the video data layer.
 *
 * Reads the video list named by video_data_param.source (whitespace-separated
 * records of "<frame directory> <number of frames> <label>"), optionally
 * shuffles it, resolves the frame file-name pattern, loads the first listed
 * video to learn the datum dimensions, and shapes the top blobs and the
 * prefetch buffers from those dimensions.
 *
 * @param bottom not used by this function.
 * @param top    top[0] is shaped as the data blob, top[1] as the label blob.
 */
template <typename Dtype>
void VideoDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top){
    const VideoDataParameter& video_data_param = this->layer_param_.video_data_param();
    const int new_height = video_data_param.new_height();
    const int new_width = video_data_param.new_width();
    const int new_length = video_data_param.new_length();
    const int num_segments = video_data_param.num_segments();
    const string& source = video_data_param.source();
    // num_segments is used as a divisor below.
    CHECK_GT(num_segments, 0) << "num_segments must be positive";

    LOG(INFO) << "Opening file: " << source;
    std::ifstream infile(source.c_str());
    CHECK(infile.good()) << "Could not open video list file: " << source;
    string filename;
    int label;
    int length;
    while (infile >> filename >> length >> label){
        lines_.push_back(std::make_pair(filename, label));
        lines_duration_.push_back(length);
    }
    // The first entry is read below to infer the blob shape, so an empty list
    // cannot be handled.
    CHECK(!lines_.empty()) << "No videos listed in file: " << source;

    if (video_data_param.shuffle()){
        // Both generators get the same seed: ShuffleVideos() shuffles lines_
        // and lines_duration_ separately and the two lists are meant to end
        // up in the same order.
        const unsigned int prefetch_rng_seed = caffe_rng_rand();
        prefetch_rng_1_.reset(new Caffe::RNG(prefetch_rng_seed));
        prefetch_rng_2_.reset(new Caffe::RNG(prefetch_rng_seed));
        ShuffleVideos();
    }

    LOG(INFO) << "A total of " << lines_.size() << " videos.";
    lines_id_ = 0;

    // Check name pattern: fall back to a per-modality default when the
    // prototxt does not provide one.
    if (video_data_param.name_pattern() == ""){
        if (video_data_param.modality() == VideoDataParameter_Modality_RGB){
            name_pattern_ = "image_%04d.jpg";
        }else if (video_data_param.modality() == VideoDataParameter_Modality_FLOW){
            name_pattern_ = "flow_%c_%04d.jpg";
        }
    }else{
        name_pattern_ = video_data_param.name_pattern();
    }

    Datum datum;
    const unsigned int frame_prefetch_rng_seed = caffe_rng_rand();
    frame_prefetch_rng_.reset(new Caffe::RNG(frame_prefetch_rng_seed));
    const int average_duration = static_cast<int>(lines_duration_[lines_id_]) / num_segments;
    vector<int> offsets;
    for (int i = 0; i < num_segments; ++i){
        if (average_duration >= new_length){
            caffe::rng_t* frame_rng = static_cast<caffe::rng_t*>(frame_prefetch_rng_->generator());
            const int offset = (*frame_rng)() % (average_duration - new_length + 1);
            offsets.push_back(offset + i*average_duration);
        } else {
            // Segment shorter than new_length: the modulus above would be zero
            // or negative, so start at the first frame, exactly as
            // InternalThreadEntry() does.
            offsets.push_back(0);
        }
    }

    // Load one sample only to discover channels/height/width.
    if (video_data_param.modality() == VideoDataParameter_Modality_FLOW)
        CHECK(ReadSegmentFlowToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                                     offsets, new_height, new_width, new_length, &datum, name_pattern_.c_str()));
    else
        CHECK(ReadSegmentRGBToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                                    offsets, new_height, new_width, new_length, &datum, true, name_pattern_.c_str()));

    const int crop_size = this->layer_param_.transform_param().crop_size();
    const int batch_size = video_data_param.batch_size();
    if (crop_size > 0){
        top[0]->Reshape(batch_size, datum.channels(), crop_size, crop_size);
        this->prefetch_data_.Reshape(batch_size, datum.channels(), crop_size, crop_size);
    } else {
        top[0]->Reshape(batch_size, datum.channels(), datum.height(), datum.width());
        this->prefetch_data_.Reshape(batch_size, datum.channels(), datum.height(), datum.width());
    }
    LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," << top[0]->width();

    // One label per batch item.
    top[1]->Reshape(batch_size, 1, 1, 1);
    this->prefetch_label_.Reshape(batch_size, 1, 1, 1);

    vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
    this->transformed_data_.Reshape(top_shape);
}

/// Shuffles the video list (lines_) and the per-video frame counts
/// (lines_duration_) in place, each with its own generator.
template <typename Dtype>
void VideoDataLayer<Dtype>::ShuffleVideos(){
    // lines_ is shuffled with the generator owned by prefetch_rng_1_ ...
    caffe::rng_t* video_rng = static_cast<caffe::rng_t*>(prefetch_rng_1_->generator());
    shuffle(lines_.begin(), lines_.end(), video_rng);
    // ... and lines_duration_ with the one owned by prefetch_rng_2_.
    caffe::rng_t* duration_rng = static_cast<caffe::rng_t*>(prefetch_rng_2_->generator());
    shuffle(lines_duration_.begin(), lines_duration_.end(), duration_rng);
}

/**
 * @brief Fills prefetch_data_ and prefetch_label_ with one batch of videos.
 *
 * For each batch slot the current video (lines_[lines_id_]) is split into
 * num_segments segments and new_length frames are read from each one: at a
 * random position in the TRAIN phase, at the centre otherwise. The frames are
 * transformed straight into the slot's region of prefetch_data_.
 *
 * A video that cannot be read is skipped: the list position moves on and the
 * same batch slot is retried with the next video, so no slot is left holding
 * stale data. If every listed video fails in a row the function aborts.
 */
template <typename Dtype>
void VideoDataLayer<Dtype>::InternalThreadEntry(){

    Datum datum;
    CHECK(this->prefetch_data_.count());
    Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
    Dtype* top_label = this->prefetch_label_.mutable_cpu_data();
    // Bind by reference; copying the parameter message is unnecessary.
    const VideoDataParameter& video_data_param = this->layer_param_.video_data_param();
    const int batch_size = video_data_param.batch_size();
    const int new_height = video_data_param.new_height();
    const int new_width = video_data_param.new_width();
    const int new_length = video_data_param.new_length();
    const int num_segments = video_data_param.num_segments();
    const bool is_flow = video_data_param.modality() == VideoDataParameter_Modality_FLOW;
    const int lines_size = lines_.size();
    // num_segments is used as a divisor below.
    CHECK_GT(num_segments, 0) << "num_segments must be positive";

    int consecutive_failures = 0;
    int item_id = 0;
    while (item_id < batch_size){
        CHECK_GT(lines_size, lines_id_);
        vector<int> offsets;
        const int average_duration = static_cast<int>(lines_duration_[lines_id_]) / num_segments;
        for (int i = 0; i < num_segments; ++i){
            if (average_duration < new_length){
                // Segment too short for new_length frames: start at the first frame.
                offsets.push_back(0);
            } else if (this->phase_ == TRAIN){
                // Training: random start inside the segment.
                caffe::rng_t* frame_rng = static_cast<caffe::rng_t*>(frame_prefetch_rng_->generator());
                const int offset = (*frame_rng)() % (average_duration - new_length + 1);
                offsets.push_back(offset + i*average_duration);
            } else {
                // Otherwise: deterministic, centred inside the segment.
                offsets.push_back(int((average_duration-new_length+1)/2 + i*average_duration));
            }
        }

        bool loaded;
        if (is_flow){
            loaded = ReadSegmentFlowToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                                            offsets, new_height, new_width, new_length, &datum, name_pattern_.c_str());
        } else {
            loaded = ReadSegmentRGBToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                                           offsets, new_height, new_width, new_length, &datum, true, name_pattern_.c_str());
        }

        if (loaded){
            // Point transformed_data_ at this slot so Transform() writes into
            // the prefetch buffer directly.
            const int offset1 = this->prefetch_data_.offset(item_id);
            this->transformed_data_.set_cpu_data(top_data + offset1);
            this->data_transformer_->Transform(datum, &(this->transformed_data_));
            top_label[item_id] = lines_[lines_id_].second;
            ++item_id;
            consecutive_failures = 0;
        } else {
            // Keep item_id unchanged so this slot is filled by the next video.
            ++consecutive_failures;
            CHECK_LT(consecutive_failures, lines_size)
                << "None of the " << lines_size << " listed videos could be read.";
        }

        // Advance to the next video whether or not this one could be read.
        lines_id_++;
        if (lines_id_ >= lines_size) {
            DLOG(INFO) << "Restarting data prefetching from start.";
            lines_id_ = 0;
            if(video_data_param.shuffle()){
                ShuffleVideos();
            }
        }
    }
}

// NOTE(review): both macros are defined outside this file. Presumably
// INSTANTIATE_CLASS explicitly instantiates the VideoDataLayer template and
// REGISTER_LAYER_CLASS registers it under the layer type name "VideoData" —
// confirm against the Caffe headers.
INSTANTIATE_CLASS(VideoDataLayer);
REGISTER_LAYER_CLASS(VideoData);
}

接下来分析其中的三个函数:
DataLayerSetUp,ShuffleVideos,InternalThreadEntry

DataLayerSetUp函数:初始化数据层
ShuffleVideos函数:将视频顺序打乱
InternalThreadEntry函数:预取线程的执行体,每次调用读取并填充一个batch的数据

函数一开始

    // Read the resize, sampling and list-file settings from video_data_param.
    const int new_height  = this->layer_param_.video_data_param().new_height();
    const int new_width  = this->layer_param_.video_data_param().new_width();
    const int new_length  = this->layer_param_.video_data_param().new_length();
    const int num_segments = this->layer_param_.video_data_param().num_segments();
    const string& source = this->layer_param_.video_data_param().source();

new_height和new_width参数在image中也有,后面的都是不同的:

new_length:每个视频片段连续读取的图片文件数(即帧数)
num_segments:将整个视频分割为几段
source:训练和test文件所在的txt文档位置

后面几行:

    LOG(INFO) << "Opening file: " << source;
    std:: ifstream infile(source.c_str());
    string filename;
    int label;
    int length;
    // Each record is read as three whitespace-separated fields:
    // file name, frame count, label.
    while (infile >> filename >> length >> label){
        lines_.push_back(std::make_pair(filename,label));
        lines_duration_.push_back(length);
    }
    if (this->layer_param_.video_data_param().shuffle()){
        // Both generators get the same seed; ShuffleVideos() uses one for
        // lines_ and the other for lines_duration_.
        const unsigned int prefectch_rng_seed = caffe_rng_rand();
        prefetch_rng_1_.reset(new Caffe::RNG(prefectch_rng_seed));
        prefetch_rng_2_.reset(new Caffe::RNG(prefectch_rng_seed));
        ShuffleVideos();
    }

    LOG(INFO) << "A total of " << lines_.size() << " videos.";
    // Start reading from the first entry of the list.
    lines_id_ = 0;

和image的基本保持一致,
这里最大的区别就是对头文件里面的几个参数的赋值

        // One random seed shared by the two shuffle generators.
        const unsigned int prefectch_rng_seed = caffe_rng_rand();
        prefetch_rng_1_.reset(new Caffe::RNG(prefectch_rng_seed));
        prefetch_rng_2_.reset(new Caffe::RNG(prefectch_rng_seed));

然后就是输入列表文件的组织顺序:每条记录依次为 文件名(帧图片所在目录)、视频帧数、类别,三者之间用空白字符分隔(对应上面的 infile >> filename >> length >> label)

// Log how many entries were read into lines_.
LOG(INFO) << "A total of " << lines_.size() << " videos.";

从这个可见lines_的每个元素对应一个视频样本:保存的是(样本文件名, 类别)对,对应的帧数保存在lines_duration_中
接下来自己定义输入参数名字匹配方法:

    // Check name pattern: when the prototxt leaves name_pattern empty, fall
    // back to a default that depends on the modality; otherwise use the
    // configured pattern as-is.
    if (this->layer_param_.video_data_param().name_pattern() == ""){
        if (this->layer_param_.video_data_param().modality() == VideoDataParameter_Modality_RGB){
            name_pattern_ = "image_%04d.jpg";
        }else if (this->layer_param_.video_data_param().modality() == VideoDataParameter_Modality_FLOW){
            name_pattern_ = "flow_%c_%04d.jpg";
        }
    }else{
        name_pattern_ = this->layer_param_.video_data_param().name_pattern();
    }

不多说。继续分析
其实这里就是对自己在caffe的prototxt里面定义的参数的解析,原来的image_data层中还解析了随机跳过参数rand_skip:

  // Check if we would need to randomly skip a few data points
  if (this->layer_param_.image_data_param().rand_skip()) {
    // Draw a starting index in [0, rand_skip).
    unsigned int skip = caffe_rng_rand() %
        this->layer_param_.image_data_param().rand_skip();
    LOG(INFO) << "Skipping first " << skip << " data points.";
    // The list must hold more entries than the number being skipped.
    CHECK_GT(lines_.size(), skip) << "Not enough points to skip";
    lines_id_ = skip;
  }

我们这里没有此参数,因此不需要。

Datum datum;    // Datum that receives the frame data read below

干货来了,将图片数据读取到结构体中:

    // Read the frame files into datum; ReadSegmentFlowToDatum and
    // ReadSegmentRGBToDatum are defined in io.cpp. CHECK aborts if the read
    // reports failure.
    if (this->layer_param_.video_data_param().modality() == VideoDataParameter_Modality_FLOW)
        CHECK(ReadSegmentFlowToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                                     offsets, new_height, new_width, new_length, &datum, name_pattern_.c_str()));
    else
        CHECK(ReadSegmentRGBToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                                    offsets, new_height, new_width, new_length, &datum, true, name_pattern_.c_str()));

以函数ReadSegmentRGBToDatum为例,定义在io.cpp中:

bool ReadSegmentRGBToDatum(const string& filename, const int label,
    const vector<int> offsets, const int height, const int width, const int length, Datum* datum, bool is_color,
    const char* name_pattern ){
    cv::Mat cv_img;
    string* datum_string;
    char tmp[30];
    int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :
        CV_LOAD_IMAGE_GRAYSCALE);
    for (int i = 0; i < offsets.size(); ++i){
        int offset = offsets[i];
        for (int file_id = 1; file_id < length+1; ++file_id){
            sprintf(tmp, name_pattern, int(file_id+offset));
            string filename_t = filename + "/" + tmp;
            cv::Mat cv_img_origin = cv::imread(filename_t, cv_read_flag);
            if (!cv_img_origin.data){
                LOG(ERROR) << "Could not load file " << filename_t;
                return false;
            }
            if (height > 0 && width > 0){
                cv::resize(cv_img_origin, cv_img, cv::Size(width, height));
            }else{
                cv_img = cv_img_origin;
            }
            int num_channels = (is_color ? 3 : 1);
            if (file_id==1 && i==0){
                datum->set_channels(num_channels*length*offsets.size());
                datum->set_height(cv_img.rows);
                datum->set_width(cv_img.cols);
                datum->set_label(label);
                datum->clear_data();
                datum->clear_float_data();
                datum_string = datum->mutable_data();
            }
            if (is_color) {
                for (int c = 0; c < num_channels; ++c) {
                  for (int h = 0; h < cv_img.rows; ++h) {
                    for (int w = 0; w < cv_img.cols; ++w) {
                      datum_string->push_back(
                        static_cast<char>(cv_img.at<cv::Vec3b>(h, w)[c]));
                    }
                  }
                }
              } else {  // Faster than repeatedly testing is_color for each pixel w/i loop
                for (int h = 0; h < cv_img.rows; ++h) {
                  for (int w = 0; w < cv_img.cols; ++w) {
                    datum_string->push_back(
                      static_cast<char>(cv_img.at<uchar>(h, w)));
                    }
                  }
              }
        }
    }
    return true;
}

看看该文件中还包含什么:
这里写图片描述
可以加载图片到数据库的很多底层函数都在这里定义。

#include "caffe/util/io.hpp"

从这可以知道我们需要在io.hpp文件中增加ReadSegmentRGBToDatum函数的声明,然后在io.cpp中去实现它。
注意新版caffe的hpp文件目录是:/home/xl/caffe/include/caffe
而cpp所在目录是 :/home/xl/caffe/src/caffe
io.hpp文件内容:

// NOTE(review): the definition is not shown here; judging by the parameter
// names it presumably loads an image and a label file into two Datums —
// confirm in io.cpp.
bool ReadSegDataToDatum(const string& img_filename, const string& label_filename, 
                        Datum* datum_data, Datum* datum_label, bool is_color);

// Flow-modality counterpart of ReadSegmentRGBToDatum, called when the layer's
// modality is FLOW. NOTE(review): definition not shown here; the parameters
// presumably have the same meaning as below — confirm in io.cpp.
bool ReadSegmentFlowToDatum(const string& filename, const int label,
    const vector<int> offsets, const int height, const int width, const int length, Datum* datum, const char* name_pattern);

// For every entry of `offsets`, reads `length` frames from directory
// `filename` (file names built from `name_pattern`), resizes them when
// height and width are both > 0, and packs them into `datum` with `label`.
// Returns false as soon as a frame cannot be loaded.
bool ReadSegmentRGBToDatum(const string& filename, const int label,
    const vector<int> offsets, const int height, const int width, const int length, Datum* datum, bool is_color,
                           const char* name_pattern);

在io.cpp中新增ReadSegmentRGBToDatum函数详解:

    cv::Mat cv_img;
    string* datum_string;
    // Buffer for the formatted frame file name.
    // NOTE(review): sprintf below does not bound its output to these 30 bytes.
    char tmp[30];
    int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :
        CV_LOAD_IMAGE_GRAYSCALE);
    for (int i = 0; i < offsets.size(); ++i){
        int offset = offsets[i];
        // Frame numbers are 1-based: offset+1 .. offset+length.
        for (int file_id = 1; file_id < length+1; ++file_id){
            sprintf(tmp, name_pattern, int(file_id+offset));
            string filename_t = filename + "/" + tmp;
            cv::Mat cv_img_origin = cv::imread(filename_t, cv_read_flag);
            if (!cv_img_origin.data){
                LOG(ERROR) << "Could not load file " << filename_t;
                return false;
            }
            // Resize only when both target dimensions are given.
            if (height > 0 && width > 0){
                cv::resize(cv_img_origin, cv_img, cv::Size(width, height));
            }else{
                cv_img = cv_img_origin;
            }  // finished reading the image

这几行在前面的官方默认读取函数中,都能找到参考,真正改动大的是下面几行:需要注意的是循环体结构:

    for (int i = 0; i < offsets.size(); ++i){  // offsets.size() = number of segments
        int offset = offsets[i];
        for (int file_id = 1; file_id < length+1; ++file_id){ // number of frames read in each segment

最终整个视频分段后所有帧的像素数据都保存在datum中(datum_string只是指向datum内部data字段的指针)。
说完io之后继续分析video_data_layers.cpp

下面需要实现ShuffleVideos函数,功能就是打乱列表顺序
非常简单:

/// Shuffles the video list (lines_) and the per-video frame counts
/// (lines_duration_) in place, each with its own generator.
template <typename Dtype>
void VideoDataLayer<Dtype>::ShuffleVideos(){
    // lines_ is shuffled with the generator owned by prefetch_rng_1_ ...
    caffe::rng_t* video_rng = static_cast<caffe::rng_t*>(prefetch_rng_1_->generator());
    shuffle(lines_.begin(), lines_.end(), video_rng);
    // ... and lines_duration_ with the one owned by prefetch_rng_2_.
    caffe::rng_t* duration_rng = static_cast<caffe::rng_t*>(prefetch_rng_2_->generator());
    shuffle(lines_duration_.begin(), lines_duration_.end(), duration_rng);
}

只比image多了两行,image的该函数如下:

/// Shuffles lines_ in place with the generator owned by prefetch_rng_.
template <typename Dtype>
void ImageDataLayer<Dtype>::ShuffleImages() {
  caffe::rng_t* image_rng = static_cast<caffe::rng_t*>(prefetch_rng_->generator());
  shuffle(lines_.begin(), lines_.end(), image_rng);
}

下面介绍另一个非常重要的函数InternalThreadEntry。注意它与solver中的test_interval: 15000(每隔多少次迭代做一次测试)并不是一回事:它是数据预取线程的执行体,每调用一次就循环读取并填充一个batch(batch_size个样本)的数据,见其中的主循环体:

// One iteration per slot of the batch.
for (int item_id = 0; item_id < batch_size; ++item_id){

在image_data中也是同样的处理方法,只是这里我们循环的时候每次处理的是一个视频,而不仅仅是一帧图片。

最后的这两行十分重要:会对层进行注册

// NOTE(review): both macros are defined outside this file. Presumably
// INSTANTIATE_CLASS explicitly instantiates the VideoDataLayer template and
// REGISTER_LAYER_CLASS registers it under the layer type name "VideoData" —
// confirm against the Caffe headers.
INSTANTIATE_CLASS(VideoDataLayer);
REGISTER_LAYER_CLASS(VideoData);

这样我们的数据层函数就构造完了,接下来要对其在prototxt中进行注册和测试。
接下篇:
caffe增加自己的layer实战(下-续1)–caffe学习(13)

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值