MKL-DNN学习笔记 (六) 实现BatchNorm层

接下来实现基于MKL-DNN的BN层

先复习一下BatchNorm层的计算公式

这个公式里mean为均值,var为方差,scale为缩放因子,shift是平移值

epsilon为一个常数,通常默认值为0.001或者可以在网络模型里读出来,主要用来防止出现方差var=0的时候出现除零错误而添加进去的

纯C++超分辨率重建DRRN --改编--(二)归一化(BatchNorm) 和 缩放和平移(Scale)里把这个操作分解为2个函数。在mkl-dnn的API里,我们可以用batch_normalization_forward来一步完成。

实现代码

	const int N = 1, H = 5, W = 5, C = 2;
	const int IC = C, OC = IC, KH = 3, KW = 3;

	// The image size
	const int image_size = N * H * W * C;
	const int bn_mean_size = C;
	const int bn_scale_shift_size = 2 * C;

	// Allocate a buffer for the image
	std::vector<float> image(image_size);
	std::vector<float> mean(bn_mean_size);
	std::vector<float> var(bn_mean_size);
	std::vector<float> scale_shift(bn_scale_shift_size);

	// 初始化输入数据
	for (int n = 0; n < N; n++)
		for (int h = 0; h < H; h++)
			for (int w = 0; w < W; w++)
				for (int c = 0; c < C; c++) {
					int off = offset(n, h, w, c); // Get the physical offset of a pixel
					image[off] = off;
				}
    //初始化均值和方差,C=0时mean=1,var=4; C=1时mean=0,var=4
	for (int n = 0; n < bn_mean_size; n++)
	{
		mean[n] = 1.0;
		var[n] = 4.0;
		if (n >= bn_mean_size / 2)
		{
			mean[n] = 0.0;
			var[n] = 4.0;
		}
	}
    //初始化scale和shift, 统一放到一个数组里,数组前半部分为scale, 后半部分为shift
	for (int n = 0; n < bn_scale_shift_size; n++)
	{
		scale_shift[n] = 2;    //scale
		if (n >= bn_scale_shift_size / 2)
		{
			scale_shift[n] = -1;    //shift
		}
	}


	memory::dims conv3_src_tz = { N, C, H, W };
	memory::dims conv3_mean_tz = { C };
	memory::dims conv3_scale_shift_tz = { 2, C };

	// [Init src_md]
	auto user_src3_md = memory::desc(
		conv3_src_tz, // logical dims, the order is defined by a primitive
		memory::data_type::f32,     // tensor's data type
		memory::format_tag::nhwc    // memory format, NHWC in this case 这里控制memory的layout
	);

	// create user memory
	auto user_conv3_src_mem = memory(user_src3_md, cpu_engine, image.data());

	/*********************** Batch Normal ***************************************************/

	auto conv3_mean_md = memory::desc(conv3_mean_tz, memory::data_type::f32, memory::format_tag::x);
	auto conv3_scale_shift_md = memory::desc(conv3_scale_shift_tz, memory::data_type::f32, memory::format_tag::nc);

	auto mean_mem = memory(conv3_mean_md, cpu_engine, mean.data());
	auto var_mem = memory(conv3_mean_md, cpu_engine, var.data());
	auto scale_shift_mem = memory(conv3_scale_shift_md, cpu_engine, scale_shift.data());

	auto user_bn_dst_mem = memory(user_src3_md, cpu_engine);

    //这个flags控制BN做哪些操作, 重点是有use_scale_shift时做BatchNorm+ScaleShift, 
    //没有use_scale_shift标志的话只做BatchNorm
	normalization_flags flags = normalization_flags::use_global_stats | normalization_flags::use_scale_shift;
	// set flags for different flavors (use | to combine flags)
	//      use_global_stats -- do not compute mean and variance in the primitive, user has to provide them
	//      use_scale_shift  -- in addition to batch norm also scale and shift the result

	auto bnrm_fwd_d = batch_normalization_forward::desc(
		prop_kind::forward_inference, // might be forward_inference, backward, backward_data
		user_src3_md,  // data descriptor (i.e. sizes, data type, and layout)
		0.001f,     // eps在这里定义
		flags);


	auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(bnrm_fwd_d, cpu_engine);
	auto bnrm_fwd = batch_normalization_forward(bnrm_fwd_pd);
	bnrm_fwd.execute(
		cpu_stream,
		{
			{ MKLDNN_ARG_SRC, user_conv3_src_mem },
			{ MKLDNN_ARG_MEAN, mean_mem },
			{ MKLDNN_ARG_VARIANCE, var_mem },
			{ MKLDNN_ARG_SCALE_SHIFT, scale_shift_mem },
			{ MKLDNN_ARG_DST, user_bn_dst_mem }
		}
	);

	/**************************************************************************/
	// Wait the stream to complete the execution
	cpu_stream.wait();

上面这部分代码,在初始化mean/var值时,针对C的不同定义了不同的mean/var。当C=0的时候mean=1, var=4 ;当C=1的时候mean=0,var=4。

BatchNorm的计算公式为 x_norm=(x-mean)/sqrt(var+eps),通过前面的定义的特殊的mean和var值,可以得到

C=0时,x_norm约等于(x-1)/2

C=1时,x_norm约等于x/2    ----这里约等于的意思是因为有个eps=0.001f在里面,所以sqrt(4+eps)出来的值不是精确的2

最终的输出结果

Y=x_norm*scale+shift

代码中定义Scale=2, shift=-1

因此 Y=x_norm*2-1

 

运行一下程序,看看是不是和我们理论计算的一样

C=0时 x=0, x_norm=(0-1)/2=-0.5, Y=(-0.5)*2-1=-2 (约等于,没有算eps)

                   x=2, x_norm=0.5 Y=0

C=1时 x=1, x_norm=1/2,  Y=1/2*2-1=0  (约等于)

                   x=3, x_norm=0.5 Y=2

 

计算结果和实际输出基本一致,再把这里的eps值改为0

    auto bnrm_fwd_d = batch_normalization_forward::desc(
        prop_kind::forward_inference, // might be forward_inference, backward, backward_data
        user_src3_md,  // data descriptor (i.e. sizes, data type, and layout)
        0.000f,     // eps
        flags);

就跟理论计算完全一样了  :)

 

到这里,BatchNorm层也实现了

 

最后代码奉上,仅供参考

https://github.com/tisandman555/mkldnn_study/blob/master/bn.cpp

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值