神经网络低比特量化——LSQ
本文为IBM的量化工作,发表在ICLR 2020。论文题目:Learned Step Size Quantization。为了解决量化精度越低,模型识别率越低的问题,本文引入了一种新的手段来估计和扩展每个权重和激活层的量化器步长大小的任务损失梯度,并在 ImageNet 上的实验和分析证明了所提出的方法的有效性, 实现了ResNet 4 bit量化不掉精度!
- 论文链接:Learned Step Size Quantization
- 源码链接(非官方复现):https://github.com/zhutmost/lsq-net
摘要
在推理时以低精度操作运行的深度网络比高精度具有功耗和存储优势,但需要克服随着精度降低而保持高精度的挑战。在这里,本文提出了一种训练此类网络的方法,即 Learned Step Size Quantization,当使用来自各种架构的模型时,该方法在 ImageNet 数据集上实现了 SOTA 的精度,其权重和激活量化为2、3或4 bit 精度,并且可以训练达到全精度基线精度的3 bit 模型。本文的方法建立在现有的量化网络中学习权重的方法基础上,通过改进量化器本身的配置方式。具体来说,本文引入了一种新的手段来估计和扩展每个权重和激活层的量化器步长大小的任务损失梯度,这样它就可以与其他网络参数一起学习。这种方法可以根据给定系统的需要使用不同的精度水平工作,并且只需要对现有的训练代码进行简单的修改。
方法
量化计算公式
v ˉ = ⌊ cip ( v / s , − Q N , Q P ) ⌉ \bar{v}=\left\lfloor\operatorname{cip}\left(v / s,-Q_{N}, Q_{P}\right)\right\rceil vˉ=⌊cip(v/s,−QN,QP)⌉
v ^ = v ˉ × s \hat{v}=\bar{v} \times s v^=vˉ×s
- s为量化的 STEP SIZE 可学习参数。s即是数据的缩放因子,又能控制数据截断的边界。
- 针对weights: Q N = 2 b − 1 and Q P = 2 b − 1 − 1 Q_{N}=2^{b-1} \text { and } Q_{P}=2^{b-1}-1 QN=2b−1 and QP=2b−1−1
- 针对data: Q N = 0 and Q P = 2 b − 1 Q_{N}=0 \text { and } Q_{P}=2^{b}-1 QN=0 and QP=2b−1
STEP SIZE GRADIENT
∂ v ^ ∂ s = { − v / s + ⌊ v / s ⌉ if − Q N < v / s < Q P − Q N if v / s ≤ − Q N Q P if v / s ≥ Q P \frac{\partial \hat{v}}{\partial s}=\left\{\begin{array}{ll} -v / s+\lfloor v / s\rceil & \text { if }-Q_{N} < v / s < Q_{P} \\ -Q_{N} & \text { if } v / s \leq-Q_{N} \\ Q_{P} & \text { if } v / s \geq Q_{P} \end{array}\right. ∂s∂v^=⎩⎨⎧−v/s+⌊v/s⌉−QNQP if −QN<v/s<QP if v/s≤−QN if v/s≥QP
STEP SIZE GRADIENT SCALE
当量化比特数增加时,step-size会变小,以确保更为精细的量化;而当量化比特数减少时,step-size会变大。为了让step-size的参数更新,能够适应量化比特数的调整,需要将step-size的梯度乘以一个scale系数。
- 权重: g = 1 / N W Q P g=1 / \sqrt{N_{W} Q_{P}} g=1/NWQP , N W N_{W} NW代表当前层的权重数。
- 激活: g = 1 / N F Q P g=1 / \sqrt{N_{F} Q_{P}} g=1/NFQP , N F N_{F} NF代表当前层的特征数。
直通估计器
量化的权重和激活用于前向和反向传递,通过 Bengio 提出的直通估计器(STE)计算,如下公式:
∂
v
^
∂
v
=
{
1
if
−
Q
N
<
v
/
s
<
Q
P
0
otherwise
\frac{\partial \hat{v}}{\partial v}=\left\{\begin{array}{ll} 1 & \text { if }-Q_{N} < v / s < Q_{P} \\ 0 & \text { otherwise } \end{array}\right.
∂v∂v^={10 if −QN<v/s<QP otherwise