量化训练之补偿STE:DSQ和QuantNoise

(本文首发于公众号,没事来逛逛)

今天讲一点量化训练中关于 STE (Straight Through Estimator) 的问题,同时介绍两种应对问题的方法:DSQ 和 QuantNoise。分别对应两篇论文:Differentiable Soft Quantization: Bridging Full-Precision and Low-Bit Neural NetworksTraining with Quantization Noise for Extreme Model Compression

阅读本文需要对量化训练的过程有基本了解,可以参考我之前的这篇文章

STE的问题

在量化训练中,由于 round 函数的存在,我们无法正常求导,因此退而求其次,在反向传播的时候用 STE 跳过了这个函数。这个「跳过」,就是把 STE 的导数默认为 1。

但这种做法有个副作用,由于它无法反应真实的量化误差,所以,不管量化位数有多少 (8 比特、4 比特等等),导数都是一样的。

看下面这个例子:

class QuantConv(nn.Module):

    def __init__(self, conv_module, bits=8):
        super(QuantConv, self).__init__()
        self.conv_module = conv_module
        self.bits = bits

    def forward(self, x):
        scale, zero_point = calcScaleZeroPoint(self.conv_module.weight.data.min(), \
            self.conv_module.weight.data.max(), num_bits=self.bits)
        weight, bias = self.conv_module.weight, self.conv_module.bias
        # 对weight做伪量化,模拟量化误差
        quant_weight = dequantize_tensor(quantize_tensor(weight, scale, zero_point, self.bits), scale, zero_point)
        # detach这一步就是STE
        return F.conv2d(x, weight + (quant_weight - weight).detach(), bias, 3, 1)

我定义了一个量化的卷积 QuantConv,对 weight 做了伪量化,其中 calcScaleZeroPointquantize_tensordequantize_tensor 这几个函数的定义可以在之前的文章中找到。

然后,我们用不同的比特数来量化,看看在 BP 的时候,梯度有什么差别:

conv = nn.Conv2d(1, 1, 3, 1)
x = torch.randn((1, 1, 4, 4))  # 使用同一个输入
quantconv = QuantConv(conv)

a = quantconv(x).sum().backward()   # BP计算梯度
print("use 8 bit")
print(quantconv.conv_module.weight.grad)

quantconv.zero_grad()
quantconv.bits = 2
a = quantconv(x).sum().backward()    # BP计算梯度
print("use 2 bit")
print(quantconv.conv_module.weight.grad)

输出结果如下:

use 8 bit
tensor([[[[ 0.6101, -2.7252, -0.2428],
          [ 2.2399,  0.5673,  1.7511],
          [-0.5968,  1.2209,  0.6866]]]])
use 2 bit
tensor([[[[ 0.6101, -2.7252, -0.2428],
          [ 2.2399,  0.5673,  1.7511],
          [-0.5968,  1.2209,  0.6866]]]])

可以发现,对同一个输入,用同样的损失函数计算梯度,不同比特数量化得到的梯度是一样的!但不同比特数带来的量化误差明显有很大差异,This is unreasonable!

当然,这个例子的 loss 比较取巧,如果用其他 loss (比如交叉熵函数),可能梯度就不会一样了。但不管是哪种 loss,到 STE 这一步就仿佛一套组合拳打在棉花上,最重要的梯度信息都扔掉了。这里面的原因就在于 STE 根本无法体现量化的损失。在低比特量化的时候,这种副作用尤其明显 (所以 QAT 在低比特训练中尤其困难,模型权重根本训不动)。

DSQ

基本思想

为了解决这个问题,一个很直接的想法是用某个可导的函数来近似 round,从而避免使用 STE。

比如说,我们知道傅立叶级数可以近似任何周期函数:
请添加图片描述
(图片摘自:https://www.zhihu.com/search?q=%E5%82%85%E7%AB%8B%E5%8F%B6%E5%8F%98%E6%8D%A2%E4%B9%8B%E6%8E%90%E6%AD%BB%E6%95%99%E7%A8%8B&utm_content=search_suggestion&type=content)

如果把 round 当成一个周期函数,那我们就可以用傅立叶级数来逼近 round 了,而傅立叶级数是可以求导的。

或者,我们也可以对 round 函数进行泰勒展开,用多项式来近似。

又或者,我们知道神经网络本身可以模拟任何函数,因此甚至可以用一个神经网络来近似 round。

不过,以上这些想法都过于复杂,计算量巨大,操作起来比较困难。

而 DSQ 做的就是引入一个相对简单的函数来模拟 round,做到计算简单,同时尽可能逼近 round 函数。

这个

  • 4
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值