(本文首发于公众号,没事来逛逛)
今天讲一点量化训练中关于 STE (Straight Through Estimator) 的问题,同时介绍两种应对问题的方法:DSQ 和 QuantNoise。分别对应两篇论文:Differentiable Soft Quantization: Bridging Full-Precision and Low-Bit Neural Networks 和 Training with Quantization Noise for Extreme Model Compression。
阅读本文需要对量化训练的过程有基本了解,可以参考我之前的这篇文章。
STE的问题
在量化训练中,由于 round 函数的存在,我们无法正常求导,因此退而求其次,在反向传播的时候用 STE 跳过了这个函数。这个「跳过」,就是把 STE 的导数默认为 1。
但这种做法有个副作用,由于它无法反应真实的量化误差,所以,不管量化位数有多少 (8 比特、4 比特等等),导数都是一样的。
看下面这个例子:
class QuantConv(nn.Module):
def __init__(self, conv_module, bits=8):
super(QuantConv, self).__init__()
self.conv_module = conv_module
self.bits = bits
def forward(self, x):
scale, zero_point = calcScaleZeroPoint(self.conv_module.weight.data.min(), \
self.conv_module.weight.data.max(), num_bits=self.bits)
weight, bias = self.conv_module.weight, self.conv_module.bias
# 对weight做伪量化,模拟量化误差
quant_weight = dequantize_tensor(quantize_tensor(weight, scale, zero_point, self.bits), scale, zero_point)
# detach这一步就是STE
return F.conv2d(x, weight + (quant_weight - weight).detach(), bias, 3, 1)
我定义了一个量化的卷积 QuantConv
,对 weight 做了伪量化,其中 calcScaleZeroPoint
、quantize_tensor
、dequantize_tensor
这几个函数的定义可以在之前的文章中找到。
然后,我们用不同的比特数来量化,看看在 BP 的时候,梯度有什么差别:
conv = nn.Conv2d(1, 1, 3, 1)
x = torch.randn((1, 1, 4, 4)) # 使用同一个输入
quantconv = QuantConv(conv)
a = quantconv(x).sum().backward() # BP计算梯度
print("use 8 bit")
print(quantconv.conv_module.weight.grad)
quantconv.zero_grad()
quantconv.bits = 2
a = quantconv(x).sum().backward() # BP计算梯度
print("use 2 bit")
print(quantconv.conv_module.weight.grad)
输出结果如下:
use 8 bit
tensor([[[[ 0.6101, -2.7252, -0.2428],
[ 2.2399, 0.5673, 1.7511],
[-0.5968, 1.2209, 0.6866]]]])
use 2 bit
tensor([[[[ 0.6101, -2.7252, -0.2428],
[ 2.2399, 0.5673, 1.7511],
[-0.5968, 1.2209, 0.6866]]]])
可以发现,对同一个输入,用同样的损失函数计算梯度,不同比特数量化得到的梯度是一样的!但不同比特数带来的量化误差明显有很大差异,This is unreasonable!
当然,这个例子的 loss 比较取巧,如果用其他 loss (比如交叉熵函数),可能梯度就不会一样了。但不管是哪种 loss,到 STE 这一步就仿佛一套组合拳打在棉花上,最重要的梯度信息都扔掉了。这里面的原因就在于 STE 根本无法体现量化的损失。在低比特量化的时候,这种副作用尤其明显 (所以 QAT 在低比特训练中尤其困难,模型权重根本训不动)。
DSQ
基本思想
为了解决这个问题,一个很直接的想法是用某个可导的函数来近似 round,从而避免使用 STE。
比如说,我们知道傅立叶级数可以近似任何周期函数:
(图片摘自:https://www.zhihu.com/search?q=%E5%82%85%E7%AB%8B%E5%8F%B6%E5%8F%98%E6%8D%A2%E4%B9%8B%E6%8E%90%E6%AD%BB%E6%95%99%E7%A8%8B&utm_content=search_suggestion&type=content)
如果把 round 当成一个周期函数,那我们就可以用傅立叶级数来逼近 round 了,而傅立叶级数是可以求导的。
或者,我们也可以对 round 函数进行泰勒展开,用多项式来近似。
又或者,我们知道神经网络本身可以模拟任何函数,因此甚至可以用一个神经网络来近似 round。
不过,以上这些想法都过于复杂,计算量巨大,操作起来比较困难。
而 DSQ 做的就是引入一个相对简单的函数来模拟 round,做到计算简单,同时尽可能逼近 round 函数。
这个