sigmoid, tanh, 量化推理
简介
在嵌入式设备,ARM的M系列,或者存硬件实现网络的推理,这时就需要所有的运算都需要用int型(int8,int15)或者自定义的数据类型。这里包括常见的conv2d,devconv2d…等算子,relu,prelu,sigmoid,tanh等非线性激活函数。
预备知识
- 神经网络的量化
- 负数补码
sigmoid/tanh量化推理
这里我们以ARM的CMSIS_5中的代码进行原理和代码的解说。
1. 查表法
sigmoid,tanh及类似的非线性激活函数都是通过查表的方式来进行推理的。这里以int8为例
1.1 表的生成
sigmoid,tanh都会趋于饱和,所以选定浮点的输入范围为 [-8, 8), 其他的数进行clip就行。
* sigmoid(8) = 0.9996646498695336
* tanh(8) = 0.9999997749296758
但是量化模型的输入应该是int8的数据类型才是,[-128,127],那我们的目的应该是通过输入[-128, 127]为索引创建 sigmoid( [-8, 8) )->[]
的映射表。
1.2 输入[-8,8)的转换
索引转换
这里有个很有意思的转换。具体如下:
128 = (uint8)(-128)
129 = (uint8)(-127)
130 = (uint8)(-126)
…
可参考:short | long | char自动类型转换的两个例子
这就是索引的转换。
输入tensor [-8,8)建立一个256的表
用numpy中的 linspace 函数
np.linspace(-8, 8, 256, endpoint=False)
然后再把这所有的256个浮点数,进行sigmoid函数,得到浮点的输出。后续再把浮点的out量化成int8.
def fp2q7(self, x):
x_int = math.floor(x*(2**7)+0.5)
if x_int >= 128 :
x_int = 127
if x_int < -128 :
x_int = -128
if x_int >= 0 :
return x_int
else :
return 0x100 + x_int # 0x100 是什么运算???
array([-8. , -7.9375, -7.875 , -7.8125, -7.75 , -7.6875, -7.625 ,
-7.5625, -7.5 , -7.4375, -7.375 , -7.3125, -7.25 , -7.1875,
-7.125 , -7.0625, -7. , -6.9375, -6.875 , -6.8125, -6.75 ,
-6.6875, -6.625 , -6.5625, -6.5 , -6.4375, -6.375 , -6.3125,
-6.25 , -6.1875, -6.125 , -6.0625, -6. , -5.9375, -5.875 ,
-5.8125, -5.75 , -5.6875, -5.625 , -5.5625, -5.5 , -5.4375,
-5.375 , -5.3125, -5.25 , -5.1875, -5.125 , -5.0625, -5. ,
-4.9375, -4.875 , -4.8125, -4.75 , -4.6875, -4.625 , -4.5625,
-4.5 , -4.4375, -4.375 , -4.3125, -4.25 , -4.1875, -4.125 ,
-4.0625, -4. , -3.9375, -3.875 , -3.8125, -3.75 , -3.6875,
-3.625 , -3.5625, -3.5 , -3.4375, -3.375 , -3.3125, -3.25 ,
-3.1875, -3.125 , -3.0625, -3. , -2.9375, -2.875 , -2.8125,
-2.75 , -2.6875, -2.625 , -2.5625, -2.5 , -2.4375, -2.375 ,
-2.3125, -2.25 , -2.1875, -2.125 , -2.0625, -2. , -1.9375,
-1.875 , -1.8125, -1.75 , -1.6875, -1.625 , -1.5625, -1.5 ,
-1.4375, -1.375 , -1.3125, -1.25 , -1.1875, -1.125 , -1.0625,
-1. , -0.9375, -0.875 , -0.8125, -0.75 , -0.6875, -0.625 ,
-0.5625, -0.5 , -0.4375, -0.375 , -0.3125, -0.25 , -0.1875,
-0.125 , -0.0625, 0. , 0.0625, 0.125 , 0.1875, 0.25 ,
0.3125, 0.375 , 0.4375, 0.5 , 0.5625, 0.625 , 0.6875,
0.75 , 0.8125, 0.875 , 0.9375, 1. , 1.0625, 1.125 ,
1.1875, 1.25 , 1.3125, 1.375 , 1.4375, 1.5 , 1.5625,
1.625 , 1.6875, 1.75 , 1.8125, 1.875 , 1.9375, 2. ,
2.0625, 2.125 , 2.1875, 2.25 , 2.3125, 2.375 , 2.4375,
2.5 , 2.5625, 2.625 , 2.6875, 2.75 , 2.8125, 2.875 ,
2.9375, 3. , 3.0625, 3.125 , 3.1875, 3.25 , 3.3125,
3.375 , 3.4375, 3.5 , 3.5625, 3.625 , 3.6875, 3.75 ,
3.8125, 3.875 , 3.9375, 4. , 4.0625, 4.125 , 4.1875,
4.25 , 4.3125, 4.375 , 4.4375, 4.5 , 4.5625, 4.625 ,
4.6875, 4.75 , 4.8125, 4.875 , 4.9375, 5. , 5.0625,
5.125 , 5.1875, 5.25 , 5.3125, 5.375 , 5.4375, 5.5 ,
5.5625, 5.625 , 5.6875, 5.75 , 5.8125, 5.875 , 5.9375,
6. , 6.0625, 6.125 , 6.1875, 6.25 , 6.3125, 6.375 ,
6.4375, 6.5 , 6.5625, 6.625 , 6.6875, 6.75 , 6.8125,
6.875 , 6.9375, 7. , 7.0625, 7.125 , 7.1875, 7.25 ,
7.3125, 7.375 , 7.4375, 7.5 , 7.5625, 7.625 , 7.6875,
7.75 , 7.8125, 7.875 , 7.9375])
但是,这个表没法索引。将上面的向量量化到int8为:
array([-128., -127., -126., -125., -124., -123., -122., -121., -120.,
-119., -118., -117., -116., -115., -114., -113., -112., -111.,
-110., -109., -108., -107., -106., -105., -104., -103., -102.,
-101., -100., -99., -98., -97., -96., -95., -94., -93.,
-92., -91., -90., -89., -88., -87., -86., -85., -84.,
-83., -82., -81., -80., -79., -78., -77., -76., -75.,
-74., -73., -72., -71., -70., -69., -68., -67., -66.,
-65., -64., -63., -62., -61., -60., -59., -58., -57.,
-56., -55., -54., -53., -52., -51., -50., -49., -48.,
-47., -46., -45., -44., -43., -42., -41., -40., -39.,
-38., -37., -36., -35., -34., -33., -32., -31., -30.,
-29., -28., -27., -26., -25., -24., -23., -22., -21.,
-20., -19., -18., -17., -16., -15., -14., -13., -12.,
-11., -10., -9., -8., -7., -6., -5., -4., -3.,
-2., -1., 0., 1., 2., 3., 4., 5., 6.,
7., 8., 9., 10., 11., 12., 13., 14., 15.,
16., 17., 18., 19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30., 31., 32., 33.,
34., 35., 36., 37., 38., 39., 40., 41., 42.,
43., 44., 45., 46., 47., 48., 49., 50., 51.,
52., 53., 54., 55., 56., 57., 58., 59., 60.,
61., 62., 63., 64., 65., 66., 67., 68., 69.,
70., 71., 72., 73., 74., 75., 76., 77., 78.,
79., 80., 81., 82., 83., 84., 85., 86., 87.,
88., 89., 90., 91., 92., 93., 94., 95., 96.,
97., 98., 99., 100., 101., 102., 103., 104., 105.,
106., 107., 108., 109., 110., 111., 112., 113., 114.,
115., 116., 117., 118., 119., 120., 121., 122., 123.,
124., 125., 126., 127.])
然后前面又所说了一个有意思的转换。
128 = (uint8)(-128)
129 = (uint8)(-127)
130 = (uint8)(-126)
…
256 = (uint8)(-1)
所以我们将上表的次序调整一下,将负数按从小到大接到正数后面。为 [0~127, -128 ~ -1] ,这样整个表就建立好了。
0x100 + x_int 负数为什么要 + 0x100
table_gen.py中sigmoid输出的浮点值需要量化成int8,然后写入到C文件中用16进制,具体的处理代码:
def fp2q7(self, x):
x_int = math.floor(x*(2**7)+0.5)
if x_int >= 128 :
x_int = 127
if x_int < -128 :
x_int = -128
if x_int >= 0 :
return x_int
else :
return 0x100 + x_int
这是因为16进制,和二进制是等价的,在计算机中正数以原码,负数以补码的形式存储,int8类型负数的补码就是 256+x(x是一个负数)。下面举几个例子:
如果sigmoid的输出量化后为-128
以0x%02x的16进制格式写入,**因为0x只能写入正数**,所以我们求出二进制的补码,将该补码当原码写入。
所以这里就是求负数的补码表示的正数。
在c中直接:
uint(-128) = 128
在python中:
0x100+(-128)= 128
128因为是正数,二进制原码为:10000000
-128的补码:1000 0000
可以看到二者是一样的
计算
CMSIS中的计算code
void arm_nn_activations_direct_q7(q7_t *data, uint16_t size, uint16_t int_width, arm_nn_activation_type type)
{
uint16_t i = size;
q7_t *pIn = data;
q7_t *pOut = data;
q7_t in;
q7_t out;
uint16_t shift_size = 3 - int_width;
const q7_t *lookup_table;
switch (type)
{
case ARM_SIGMOID:
lookup_table = sigmoidTable_q7;
break;
case ARM_TANH:
default:
lookup_table = tanhTable_q7;
break;
}
while (i)
{
in = *pIn++;
out = lookup_table[(uint8_t)(in >> shift_size)];
*pOut++ = out;
i--;
}
}
解释
- 这里的int_width默认为3,这是和前面输入的范围[-8, 8)对应的。而且建立的表的输入范围为[-8,8),这个的int_width是用来控制输入的范围的。int_width=2, 表示输入的范围为[-4, 4), int_width=1, 输入的范围为[-1, 1),同时可以用原来的表。
- (uint8_t)(in >> shift_size),shift_size默认为0,于是直接就是(uint8_t)in,这和前面将的次序调整有关。