本文主要分析caffe
源码分析-Blob
,主要如下几个方面:
-
overview整体上了解caffe的Blob
-
Blob 成员变量
-
Blob
主要函数,核心在于Blob的使用实例以及其与opencv
Mat
的操作的相互转化(附带运行结果基于CLion
)
overview
Blob
是Caffe
作为数据传输的媒介,无论是网络权重参数,还是输入数据,都是转化为Blob
数据结构来存储,网络,求解器等都是直接与此结构打交道的。
其直观的可以把它看成一个有4维的结构体(包含数据和梯度),而实际上,它们只是一维的指针而已,其4维结构通过shape属性得以计算出来(根据C语言的数据顺序)。
Blob在也不一定全是4维的,例如全连接层的参数就没有用四维,后期的版本已经deprecated,而是直接用
vector<int> shape_
成员变量
Blob中的主要数据成员如下,实际是在SyncedMemory上做了一层包装(SyncedMemory介绍见上一篇blog):
protected:
shared_ptr<SyncedMemory> data_; //存储前向传递数据
shared_ptr<SyncedMemory> diff_; //存储反向传递梯度
shared_ptr<SyncedMemory> shape_data_;// 参数维度old version
vector<int> shape_; //参数维度
int count_; //Blob存储的元素个数(shape_所有元素乘积)
int capacity_;//当前Blob的元素个数(控制动态分配)
主要函数
主要分析如下几类函数:
-
构造函数, 以及Reshape函数()
-
索引、返回N、C、H、W相关函数
-
gpu、cpu同步函数, 以及数据的获取
-
简单的数据处理如scale_data对数据缩放(底层调用了cblas库的运算)
-
Blob的示例,数据赋值以及和opencv Mat的操作
-
Blob对应的protobuf结构体BlobShape、BlobProto、BlobProtoVector
1. 构造函数, 以及Reshape函数()
构造函数分类两种类型:
-
默认的什么参数
-
传入N、C、H、W构造,最终调用Reshape函数
Blob() //构造函数:初始化列表 {空函数体}
: data_(), diff_(), count_(0), capacity_(0) {}
// @brief Deprecated; use <code>Blob(const vector<int>& shape)</code>.
explicit Blob(const int num, const int channels, const int height,
const int width); //可以通过设置数据维度(N,C,H,W)初始化
// 也可以通过传入vector<int>直接传入维数
explicit Blob(const vector<int>& shape);
Blob() //构造函数:初始化列表 {空函数体}
: data_(), diff_(), count_(0), capacity_(0) {}
// @brief Deprecated; use <code>Blob(const vector<int>& shape)</code>.
explicit Blob(const int num, const int channels, const int height,
const int width); //可以通过设置数据维度(N,C,H,W)初始化
// 也可以通过传入vector<int>直接传入维数
explicit Blob(const vector<int>& shape);
下面重点看下Reshape函数
template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
const int width) {
vector<int> shape(4);
shape[0] = num;
shape[1] = channels;
shape[2] = height;
shape[3] = width;
Reshape(shape);
}
// 完成blob形状shape_的记录,大小count_的计算,合适大小capacity_存储的申请
template <typename Dtype>
void Blob<Dtype>::Reshape(const vector<int>& shape) {
CHECK_LE(shape.size(), kMaxBlobAxes);
count_ = 1;
shape_.resize(shape.size());
if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) {
shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int)));
}
int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
for (int i = 0; i < shape.size(); ++i) {
CHECK_GE(shape[i], 0);
CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
count_ *= shape[i];
shape_[i] =