ncnn, bfloat16
1. 前言
首先可以看作者的博客用bf16加速ncnn
然后理解什么是bfloat16,维基百科
接下来从c,neon,assembly三个方面看ncnn是如何用bfloat16的
总体思路就是将float32数据类型的二进制表示中的后16直接砍掉
2. c
首先看c中bfloat16和float32之间的转换。
inline unsigned short float32_to_bfloat16(float value)
{
// 16 : 16
union
{
unsigned int u;
float f;
} tmp;
tmp.f = value;
return tmp.u >> 16;
}
// convert brain half to float
inline float bfloat16_to_float32(unsigned short value)
{
// 16 : 16
union
{
unsigned int u;
float f;
} tmp;
tmp.u = value << 16;
return tmp.f;
}
解释
- unsigned int, float都是32bit的。这里采用union,使不同类型的变量能访问内存中的相同位置
- 比如
float32_to_bfloat16
中,将float32的二进制表示向右移动16bit,然后通过unsigned short返回高16bit。这里返回的unsigned short变量。没有任何实际意义,ncnn中也没有计算,需要再次转换成float32才能用于计算 bfloat16_to_float32
中,将上面个的变量右移16bit,再按float进行解析出tmp.f,这时候才能用于计算。
3. neon
bfloat16和float32转换部分
inline uint16x4_t vcvt_bf16_f32(float32x4_t _v)
{
return vshrn_n_u32(vreinterpretq_u32_f32(_v), 16);
}
inline float32x4_t vcvt_f32_bf16(uint16x4_t _v)
{
return vreinterpretq_f32_u32(vshll_n_u16(_v, 16));
}
hardsigmoid运算例子部分
float32x4_t _zero = vdupq_n_f32(0.f);
float32x4_t _one = vdupq_n_f32(1.f);
while (nn--)
{
float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr));
float32x4_t _ans = vdupq_n_f32(beta);
_ans = vmlaq_n_f32(_ans, _p, alpha);
_ans = vmaxq_f32(_ans, _zero);
_ans = vminq_f32(_ans, _one);
vst1_u16(ptr, vcvt_bf16_f32(_ans));
ptr += 4;
}
解释
vreinterpretq_f32_u32(vshll_n_u16(_v, 16));
bfloat16转float32,还是右移16bit,转成uint32,再uint32转float32。- 其核心计算部分还是采用的float32,只是中间tensor采用了bfloat16,可能急速在数据载入的时候
4. assembly
这里采用relu_arm的核心代码作为例子
"prfm pldl1keep, [%0, #256] \n"
"ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%0] \n"
"shll v0.4s, v0.4h, #16 \n"
"shll v1.4s, v1.4h, #16 \n"
"shll v2.4s, v2.4h, #16 \n"
"shll v3.4s, v3.4h, #16 \n"
"fmax v0.4s, v0.4s, v16.4s \n"
"fmax v1.4s, v1.4s, v16.4s \n"
"fmax v2.4s, v2.4s, v16.4s \n"
"fmax v3.4s, v3.4s, v16.4s \n"
"shrn v0.4h, v0.4s, #16 \n"
"shrn v1.4h, v1.4s, #16 \n"
"shrn v2.4h, v2.4s, #16 \n"
"shrn v3.4h, v3.4s, #16 \n"
解释:
"shll v0.4s, v0.4h, #16 \n"
:就是uint16右移16bit变成32bit- 然后,汇编中没有uint32转float32的操作,前面neon是有vreinterpretq_f32_u32函数的,但是其是没有对应的汇编代码的。
- 关键是后面的
fmax
,该汇编指令是针对float32数据类型的,umax
才是针对int数据类型的。因为uint32和float32都是32bit,只是不同数据类型解析的方式不同。 "shrn v0.4h, v0.4s, #16 \n"
再右移16bit
5. 看作者的测试结果
来自:用bf16加速ncnn
总结
- 采用bfloat16做中间tensor的保存类型,核心计算的时候还是采用float32类型。速度的提升主要来源于op数据载入,写入上。
- 所以其提升有限,和支持float16的方式应该有一段的差距。