ReLU层的实现相对比较简单,第一篇文章里用到的cpu_getting_started.cpp实际上就是一个ReLU的实现。所以我们只要分析一下这部分代码就可以了,正好借着这个机会看看整个mkldnn编程的思路
官方mkldnn的开发文档可以在这里看到 https://intel.github.io/mkl-dnn/
整个代码的流程如下
可以看到整个mkldnn的开发思路都是任何对象(例如内存对象,计算对象)都是先创建一个格式描述对象descriptor,在格式描述里说明了对象的格式参数,然后再通过descriptor创建对象。
内存格式的描述
auto src_md = memory::desc(
{ N, C, H, W }, // logical dims, 应该是传递每一个维度的stride,不是很明白:)
memory::data_type::f32, // 数据类型是32bit浮点
memory::format_tag::nhwc // 内存对象的格式,这里是nhwc, 实际应用里这里通常是nchw
);
ReLU操作的描述
auto relu_d = eltwise_forward::desc(
prop_kind::forward_inference, //操作类型,这里是推理
algorithm::eltwise_relu, //计算方式是relu
src_md, // 输入内存对象的格式描述
0.f, // alpha parameter means negative slope in case of ReLU ReLU的参数,对应x<0.f = 0 如果要是表达x<0.005 = 0的话,这里应该是0.005f
0.f // beta parameter is ignored in case of ReLU 没用
);
engine cpu_engine(engine::kind::cpu, 0); //创建一个基于CPU计算的引擎,mkldnn也许以后会支持gpu, 头文件和一些例子里看到这里可以传gpu参数进去
stream cpu_stream(cpu_engine); //整个mkldnn的处理是基于一个流处理的概念,这里是创建这个流处理对象
relu.execute(
cpu_stream, // 指定在哪个流对象里处理
{ // A map with all inputs and outputs
{ MKLDNN_ARG_SRC, src_mem }, // 指定relu输入的源内存
{ MKLDNN_ARG_DST, dst_mem }, // 指定输出的目标内存
});
//最终relu.execute()开始计算后,要通过调用
cpu_stream.wait();
//来等待整个数据流处理结束才可以从内存对象里读取数据来后处理, 如果不等待就直接处理目标内存对象里的数据,很可能会出错,因为可能会读到未经处理的数据
最后验证一下ReLU计算是否正确
// This is safe since we created `dst_mem` as f32 tensor with known
// memory format.
//通过get_data_handle()来获取内存对象里指向float数据的首指针
float *relu_image = static_cast<float *>(dst_mem.get_data_handle());
// Check the results
for (int n = 0; n < N; ++n)
for (int h = 0; h < H; ++h)
for (int w = 0; w < W; ++w)
for (int c = 0; c < C; ++c) {
int off = offset(n, h, w, c); // get the physical offset of a pixel
float expected = image[off] < 0 ? 0.f : image[off]; // expected value
//这里relu_image[off]为计算结果,跟原始image[off]手动relu计算的expected值做一个比较,看看是否正确
if (relu_image[off] != expected) {
std::stringstream ss;
ss << "Unexpected output at index("
<< n << ", " << c << ", " << h << ", " << w << "): "
<< "Expect " << expected << " "
<< "Got " << relu_image[off];
throw ss.str();
}
}
// [Check the results]
代码奉上,仅供参考
https://github.com/tisandman555/mkldnn_study/blob/master/relu.cpp