注:
在上一篇 《量化感知训练 —— LSQ(一:原理)》中 分享了对LSQ算法的原理的理解。本篇将进一步分享对LSQ代码的理解。上一篇如有对算法理解不对的地方,还请见谅,也欢迎各位阅读者评论区指正与讨论。笔者愿这些分享能真正地能为读者朋友提供一些帮助,唯恐自己水平不够,理解不对,而误人子弟。希望读者朋友能阅读LSQ算法以及SLQ+的原文, 这样能有自己对算法的理解,如果对知识有疑问,建议应多思考,多途径去寻找答案,而非仅限于某篇博客的分享。唯如此,不至于由笔者笔误或者理解错误导致读者朋友产生误解。
LSQ原文:[1902.08153] Learned Step Size Quantization (arxiv.org)
对一个算法的理解过程一般会有三个过程:
语文描述 ——> 数学表达 ——> 代码实现
代码实现其实就是把数学公式用代码进行表达。量化感知训练的本质就是在训练过程中插入伪量化节点,通过伪量化节点的 “量化 —— 反量化” 过程来学习量化参数。在实现代码的过程中,一般会有以下问题:
1. 怎么在原模型中实现伪量化节点的插入?
2. 伪量化节点的量化与反量化怎么实现?
3. 怎么通过伪量化节点的参数更新?
我们以逐层量化的方式来进行伪量化,对每一层的参数进行伪量化,那么就需要先定义每一层算子对应的伪量化算子。比如对于nn.Conv2d, 需要定义一个对于的QuanConv2d的伪量化算子,该算子既能实现nn.Conv2d的功能,又能实现伪量化。以nn.Conv2d为例,具体的参考代码如下:
import mpmath
import torch
import torch.nn as nn
import torch.nn.functional as F
class QuanConv2d(nn.Conv2d):
def __init__(self,
m: nn.Conv2d, # 初始化,需要进行伪量化的算子
quan_a_fn=None, # 激活伪量化方法,通过该方法对激活值进行伪量化
quan_w_fn=None, # 权重伪量化方法,通过该方法对权重进行伪量化
quan_b_fn=None): # 偏置伪量化方法,通过该方法对偏置进行伪量化
assert isinstance(m, nn.Conv2d), "Input Operation should be nn.Conv2d"
super(QuanConv2d, self).__init__(
m.in_channels,
m.out_channels,
m.kernel_size,
m.stride,
m.padding,
m.bias
) # 继承对应的nn.Conv2d的参数
self.quan_w_fn = quan_w_fn
self.quan_a_fn = quan_a_fn
self.quan_b_fn = quan_b_fn
self.weight = nn.Parameter(m.weight.detach())
if m.bias:
self.bias = nn.Parameter(m.bias.detach())
else:
self.bias = None
def forward(self, x):
if self.quan_b_fn is not None:
quantized_bias = self.quan_b_fn(self.bias) if \
self.bias is not None else None # 对偏置进行伪量化
else:
quantized_bias = self.bias
if self.quan_w_fn is not None:
quantized_weight = self.quan_w_fn(self.weight) # 对权重进行伪量化
# 伪量化后的前向计算
out = F.conv2d(x, quantized_weight, quantized_bias,
self.kernel_size, self.stride, self.padding)
else:
out = F.conv2d(x, self.weight, quantized_bias, self.kernel_size,
self.stride, self.padding)
if self.quan_a_fn is not None:
out = self.quan_a_fn(out) # 对卷积后的激活值进行伪量化
return out
以上代码块,定义了一个nn.Conv2d对应的伪量化算子,通过这个伪量化算子,就能实现Conv2d的前向计算,以及伪量化。在这个伪量化算子中,其中还有伪量化方法需要作为输入传入进行初始化的,而具体的伪量化,以及需要学习的量化参数的更新方法,则通过伪量化方法实现。对于其他的算子,比如nn.Linear,则使用类似的定义实现。
这里我们再根据LSQ算法的数学原理,来实现伪量化过程,具体参考代码如下:
import torch
import torch.nn as nn
def grad_scale(x, scale): # 其中x为输入的学习参数s, scale为对参数s做缩放的缩放系数,
# 论文中设置该系数是有利于参数更新
y = x
y_grad = x * scale # y_grad用于参与梯度计算
# 返回的实际值是x的值,但是用于计算梯度的却是 x * scale
return (y - y_grad).detach() + y_grad
def round_pass(x):
# x为输入的weight或激活值,先对x取整
y = x.round()
# y_grad用来计算梯度,这里通过直通估计(STE)得到,直接取取整前的值用来进行梯度计算
y_grad = x
# 返回的值是取整的值,但是用于梯度更新的是y_grad
return (y - y_grad).detach() + y_grad
class LsqQuan(nn.Module):
def __init__(self, bit, all_positive=False, symmetric=False, per_channel=False):
super().__init__()
if all_positive:
assert not symmetric, "Positive quantization cannot be symmetric"
# 无符号量化范围: [0, 2^bit -1], 如当bit=8, 范围为: [0, 255]
self.thd_neg = 0
self.thd_pos = 2 ** bit - 1
else:
# 有符号量化, 如当bit=8, 范围为: [-128, 127]
self.thd_neg = - 2 ** (bit - 1)
self.thf_pos = 2 ** (bit - 1) - 1
self.per_channel = per_channel
self.s = nn.Parameter(torch.tensor(1.0))
def init_from(self, x, *args, **kwargs):
if self.per_channel:
self.s = nn.Parameter(x.detach().abs().mean(dim=list(range(1, x.dim())),
keepdim=True) * 2 / self.thd_pos ** 0.5)
else:
self.s = nn.Parameter(x.detach().abs().mean() * 2 / (self.thd_pos ** 0.5))
def forward(self, x):
s_grad_scale = 1.0 / ((self.thd_pos * x.numel()) ** 0.5)
s_scale = grad_scale(self.s, s_grad_scale) # 获取用于计算的scale, 并对乘上系数的scale取梯度
x = x / s_scale
x = torch.clamp(x, self.thd_neg, self.thd_pos)
x = round_pass(x) # 量化, 取整, 并对取整后的值取梯度
x = x * s_scale # 反量化
return x
至此,我们实现了伪量化算子,伪量化算子中利用伪量化方法进行参数伪量化,之后又定义了伪量化的类来实现伪量化方法。之后需要做的是,就是把原模型中的各个算子,替换成这里定义好的对应的伪量化算子,这样我们就可以实现一个插入LSQ伪量化节点的模型了,然后利用这个模型进行正常的训练,就可以进行LSQ量化感知训练了。之后的分享再进一步介绍怎么把普通的深度学习模型替换为LSQ量化模型。