量化感知训练 —— LSQ(二:代码解析)

注:

        在上一篇 《量化感知训练 —— LSQ(一:原理)》中 分享了对LSQ算法的原理的理解。本篇将进一步分享对LSQ代码的理解。上一篇如有对算法理解不对的地方,还请见谅,也欢迎各位阅读者评论区指正与讨论。笔者愿这些分享能真正地能为读者朋友提供一些帮助,唯恐自己水平不够,理解不对,而误人子弟。希望读者朋友能阅读LSQ算法以及SLQ+的原文, 这样能有自己对算法的理解,如果对知识有疑问,建议应多思考,多途径去寻找答案,而非仅限于某篇博客的分享。唯如此,不至于由笔者笔误或者理解错误导致读者朋友产生误解。

LSQ原文:[1902.08153] Learned Step Size Quantization (arxiv.org)

        对一个算法的理解过程一般会有三个过程:

                语文描述 ——> 数学表达 ——> 代码实现

        代码实现其实就是把数学公式用代码进行表达。量化感知训练的本质就是在训练过程中插入伪量化节点,通过伪量化节点的 “量化 —— 反量化” 过程来学习量化参数。在实现代码的过程中,一般会有以下问题:        

        1. 怎么在原模型中实现伪量化节点的插入?

        2. 伪量化节点的量化与反量化怎么实现?

        3. 怎么通过伪量化节点的参数更新?

         我们以逐层量化的方式来进行伪量化,对每一层的参数进行伪量化,那么就需要先定义每一层算子对应的伪量化算子。比如对于nn.Conv2d, 需要定义一个对于的QuanConv2d的伪量化算子,该算子既能实现nn.Conv2d的功能,又能实现伪量化。以nn.Conv2d为例,具体的参考代码如下:

import mpmath
import torch
import torch.nn as nn
import torch.nn.functional as F


class QuanConv2d(nn.Conv2d):
    def __init__(self,
                 m: nn.Conv2d,    # 初始化,需要进行伪量化的算子
                 quan_a_fn=None,  # 激活伪量化方法,通过该方法对激活值进行伪量化
                 quan_w_fn=None,  # 权重伪量化方法,通过该方法对权重进行伪量化
                 quan_b_fn=None): # 偏置伪量化方法,通过该方法对偏置进行伪量化
        assert isinstance(m, nn.Conv2d), "Input Operation should be nn.Conv2d"
        super(QuanConv2d, self).__init__(
            m.in_channels,
            m.out_channels,
            m.kernel_size,
            m.stride,
            m.padding,
            m.bias
        )   # 继承对应的nn.Conv2d的参数
        self.quan_w_fn = quan_w_fn
        self.quan_a_fn = quan_a_fn
        self.quan_b_fn = quan_b_fn
        self.weight = nn.Parameter(m.weight.detach())
        if m.bias:
            self.bias = nn.Parameter(m.bias.detach())
        else:
            self.bias = None

    def forward(self, x): 
        if self.quan_b_fn is not None:
            quantized_bias = self.quan_b_fn(self.bias) if \
                self.bias is not None else None # 对偏置进行伪量化
        else:
            quantized_bias = self.bias

        if self.quan_w_fn is not None:
            quantized_weight = self.quan_w_fn(self.weight)  # 对权重进行伪量化

            # 伪量化后的前向计算
            out = F.conv2d(x, quantized_weight, quantized_bias,
                           self.kernel_size, self.stride, self.padding) 
        else:
            out = F.conv2d(x, self.weight, quantized_bias, self.kernel_size,
                           self.stride, self.padding)
        
        if self.quan_a_fn is not None:
            out = self.quan_a_fn(out) # 对卷积后的激活值进行伪量化
        return out

        以上代码块,定义了一个nn.Conv2d对应的伪量化算子,通过这个伪量化算子,就能实现Conv2d的前向计算,以及伪量化。在这个伪量化算子中,其中还有伪量化方法需要作为输入传入进行初始化的,而具体的伪量化,以及需要学习的量化参数的更新方法,则通过伪量化方法实现。对于其他的算子,比如nn.Linear,则使用类似的定义实现。

        这里我们再根据LSQ算法的数学原理,来实现伪量化过程,具体参考代码如下:

import torch
import torch.nn as nn


def grad_scale(x, scale):  # 其中x为输入的学习参数s, scale为对参数s做缩放的缩放系数,
                           # 论文中设置该系数是有利于参数更新
    y = x
    y_grad = x * scale   # y_grad用于参与梯度计算
    # 返回的实际值是x的值,但是用于计算梯度的却是 x * scale
    return (y - y_grad).detach() + y_grad


def round_pass(x):
    # x为输入的weight或激活值,先对x取整
    y = x.round()
    # y_grad用来计算梯度,这里通过直通估计(STE)得到,直接取取整前的值用来进行梯度计算
    y_grad = x
    
    # 返回的值是取整的值,但是用于梯度更新的是y_grad
    return (y - y_grad).detach() + y_grad


class LsqQuan(nn.Module):
    def __init__(self, bit, all_positive=False, symmetric=False, per_channel=False):
        super().__init__()
        if all_positive:
            assert not symmetric, "Positive quantization cannot be symmetric"
            # 无符号量化范围: [0, 2^bit -1], 如当bit=8, 范围为: [0, 255]
            self.thd_neg = 0
            self.thd_pos = 2 ** bit - 1
        else:
            # 有符号量化, 如当bit=8, 范围为: [-128, 127]
            self.thd_neg = - 2 ** (bit - 1)
            self.thf_pos = 2 ** (bit - 1) - 1
        self.per_channel = per_channel
        self.s = nn.Parameter(torch.tensor(1.0))

    def init_from(self, x, *args, **kwargs):
        if self.per_channel:
            self.s = nn.Parameter(x.detach().abs().mean(dim=list(range(1, x.dim())),
                                                        keepdim=True) * 2 / self.thd_pos ** 0.5)
        else:
            self.s = nn.Parameter(x.detach().abs().mean() * 2 / (self.thd_pos ** 0.5))
    
    def forward(self, x):
        s_grad_scale = 1.0 / ((self.thd_pos * x.numel()) ** 0.5)
        s_scale = grad_scale(self.s, s_grad_scale) # 获取用于计算的scale, 并对乘上系数的scale取梯度
        
        x = x / s_scale
        x = torch.clamp(x, self.thd_neg, self.thd_pos)
        x = round_pass(x) # 量化, 取整, 并对取整后的值取梯度
        x = x * s_scale # 反量化
        return x

        至此,我们实现了伪量化算子,伪量化算子中利用伪量化方法进行参数伪量化,之后又定义了伪量化的类来实现伪量化方法。之后需要做的是,就是把原模型中的各个算子,替换成这里定义好的对应的伪量化算子,这样我们就可以实现一个插入LSQ伪量化节点的模型了,然后利用这个模型进行正常的训练,就可以进行LSQ量化感知训练了。之后的分享再进一步介绍怎么把普通的深度学习模型替换为LSQ量化模型。

  • 6
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值