Next, let's implement the BatchNorm layer with MKL-DNN.
First, a quick review of how a BatchNorm layer is computed:
x_norm = (x - mean) / sqrt(var + eps)
Y = scale * x_norm + shift
In this formula, mean is the mean, var is the variance, scale is the scaling factor, and shift is the offset value.
epsilon is a constant, usually defaulting to 0.001 or read from the network model; it is added mainly to avoid a division-by-zero error when var = 0.
In 纯C++超分辨率重建DRRN --改编--(二)归一化(BatchNorm) 和 缩放和平移(Scale) (the pure C++ DRRN super-resolution series, part 2: BatchNorm and Scale), this operation was split into two separate functions. With the mkl-dnn API, we can do both in a single step using batch_normalization_forward.
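For reference, here is a minimal scalar sketch of that two-step decomposition (the helper names are hypothetical, not the actual code from that post):
#include <cmath>
// Step 1 (BatchNorm): normalize with the per-channel mean/variance
float batch_norm(float x, float mean, float var, float eps) {
    return (x - mean) / std::sqrt(var + eps);
}
// Step 2 (Scale): scale and shift the normalized value
float scale_shift(float x_norm, float scale, float shift) {
    return scale * x_norm + shift;
}
// batch_normalization_forward fuses both steps:
// Y = scale * (x - mean) / sqrt(var + eps) + shift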
Implementation code:
// Assumed setup from the full bn.cpp (not shown in this excerpt):
//   #include "mkldnn.hpp"
//   using namespace mkldnn;
//   engine cpu_engine(engine::kind::cpu, 0);
//   stream cpu_stream(cpu_engine);
const int N = 1, H = 5, W = 5, C = 2;
const int IC = C, OC = IC, KH = 3, KW = 3; // leftover from the convolution example; unused below
// The image size
const int image_size = N * H * W * C;
const int bn_mean_size = C;
const int bn_scale_shift_size = 2 * C;
// Allocate a buffer for the image
std::vector<float> image(image_size);
std::vector<float> mean(bn_mean_size);
std::vector<float> var(bn_mean_size);
std::vector<float> scale_shift(bn_scale_shift_size);
// Initialize the input data
// offset() returns the physical offset of an NHWC pixel; assumed to be defined
// as in the oneDNN getting-started example:
auto offset = [=](int n, int h, int w, int c) {
    return n * H * W * C + h * W * C + w * C + c;
};
for (int n = 0; n < N; n++)
for (int h = 0; h < H; h++)
for (int w = 0; w < W; w++)
for (int c = 0; c < C; c++) {
int off = offset(n, h, w, c); // Get the physical offset of a pixel
image[off] = off;
}
// Initialize mean and variance: channel 0 gets mean=1, var=4; channel 1 gets mean=0, var=4
for (int n = 0; n < bn_mean_size; n++)
{
    mean[n] = (n < bn_mean_size / 2) ? 1.0f : 0.0f;
    var[n] = 4.0f;
}
// Initialize scale and shift, packed into one array: the first half is scale, the second half is shift
for (int n = 0; n < bn_scale_shift_size; n++)
{
    scale_shift[n] = (n < bn_scale_shift_size / 2) ? 2.0f : -1.0f; // scale = 2, shift = -1
}
memory::dims conv3_src_tz = { N, C, H, W };
memory::dims conv3_mean_tz = { C };
memory::dims conv3_scale_shift_tz = { 2, C };
// [Init src_md]
auto user_src3_md = memory::desc(
conv3_src_tz, // logical dims, the order is defined by a primitive
memory::data_type::f32, // tensor's data type
memory::format_tag::nhwc // memory format, NHWC in this case; this controls the memory layout
);
// create user memory
auto user_conv3_src_mem = memory(user_src3_md, cpu_engine, image.data());
/*********************** Batch Normal ***************************************************/
auto conv3_mean_md = memory::desc(conv3_mean_tz, memory::data_type::f32, memory::format_tag::x);
auto conv3_scale_shift_md = memory::desc(conv3_scale_shift_tz, memory::data_type::f32, memory::format_tag::nc);
auto mean_mem = memory(conv3_mean_md, cpu_engine, mean.data());
auto var_mem = memory(conv3_mean_md, cpu_engine, var.data());
auto scale_shift_mem = memory(conv3_scale_shift_md, cpu_engine, scale_shift.data());
auto user_bn_dst_mem = memory(user_src3_md, cpu_engine);
// These flags control which operations BN performs: with use_scale_shift it does
// BatchNorm + ScaleShift; without use_scale_shift it does BatchNorm only
normalization_flags flags = normalization_flags::use_global_stats | normalization_flags::use_scale_shift;
// set flags for different flavors (use | to combine flags)
// use_global_stats -- do not compute mean and variance in the primitive, user has to provide them
// use_scale_shift -- in addition to batch norm also scale and shift the result
auto bnrm_fwd_d = batch_normalization_forward::desc(
prop_kind::forward_inference, // might be forward_training or forward_inference
user_src3_md, // data descriptor (i.e. sizes, data type, and layout)
0.001f, // eps is defined here
flags);
auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(bnrm_fwd_d, cpu_engine);
auto bnrm_fwd = batch_normalization_forward(bnrm_fwd_pd);
bnrm_fwd.execute(
cpu_stream,
{
{ MKLDNN_ARG_SRC, user_conv3_src_mem },
{ MKLDNN_ARG_MEAN, mean_mem },
{ MKLDNN_ARG_VARIANCE, var_mem },
{ MKLDNN_ARG_SCALE_SHIFT, scale_shift_mem },
{ MKLDNN_ARG_DST, user_bn_dst_mem }
}
);
/**************************************************************************/
// Wait the stream to complete the execution
cpu_stream.wait();
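To see the numbers for yourself, a small dump loop can print the result channel by channel. This is just a sketch: it assumes the offset() helper defined above and needs <cstdio>.
// Sketch: dump the BN output per channel (assumes offset() from above)
float *bn_dst = static_cast<float *>(user_bn_dst_mem.get_data_handle());
for (int c = 0; c < C; c++) {
    printf("channel %d:\n", c);
    for (int h = 0; h < H; h++) {
        for (int w = 0; w < W; w++)
            printf("%7.3f ", bn_dst[offset(0, h, w, c)]);
        printf("\n");
    }
}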
In the code above, the mean/var initialization gives each channel different values: channel 0 gets mean=1, var=4; channel 1 gets mean=0, var=4.
BatchNorm computes x_norm = (x - mean)/sqrt(var + eps), so with these particular mean/var values we get:
channel 0: x_norm ≈ (x - 1)/2
channel 1: x_norm ≈ x/2
(These are only approximate because eps = 0.001f is added in, so sqrt(4 + eps) ≈ 2.00025 rather than exactly 2.)
The final output is
Y = x_norm * scale + shift
The code sets scale = 2 and shift = -1,
so Y = x_norm * 2 - 1. Ignoring eps, that simplifies per channel to Y ≈ x - 2 on channel 0 and Y ≈ x - 1 on channel 1.
Let's run the program and see whether it matches the hand computation (eps ignored in these numbers):
channel 0: x=0, x_norm = (0-1)/2 = -0.5, Y = (-0.5)*2 - 1 = -2
           x=2, x_norm = 0.5, Y = 0
channel 1: x=1, x_norm = 1/2, Y = (1/2)*2 - 1 = 0
           x=3, x_norm = 3/2, Y = (3/2)*2 - 1 = 2
The computed values match the actual output almost exactly. Now change the eps value here to 0:
auto bnrm_fwd_d = batch_normalization_forward::desc(
prop_kind::forward_inference, // might be forward_training or forward_inference
user_src3_md, // data descriptor (i.e. sizes, data type, and layout)
0.000f, // eps
flags);
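For reference, with eps = 0 the output can be worked out exactly from Y = x - 2 on channel 0 and Y = x - 1 on channel 1 (recall image[off] = off in NHWC layout, so channel 0 holds the even offsets and channel 1 the odd ones). The expected values are:
channel 0:
 -2  0  2  4  6
  8 10 12 14 16
 18 20 22 24 26
 28 30 32 34 36
 38 40 42 44 46
channel 1:
  0  2  4  6  8
 10 12 14 16 18
 20 22 24 26 28
 30 32 34 36 38
 40 42 44 46 48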
and the result agrees with the theoretical computation exactly :)
With that, the BatchNorm layer is implemented too.
Finally, the full code, for reference:
https://github.com/tisandman555/mkldnn_study/blob/master/bn.cpp