FakeLinearQuantization: defined below. This module applies fake (simulated) quantization to its input. During training it tracks the range of the activations and updates scale and zero_point accordingly (at inference time it simply reuses the scale/zero_point obtained at the end of training). The fake quantization itself is implemented with LinearQuantizeSTE (a straight-through estimator).

class FakeLinearQuantization(nn.Module):
def __init__(self, num_bits=8, mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999, dequantize=True, inplace=False):
"""
        :param num_bits: number of bits used for quantization
        :param mode: quantization mode (a LinearQuantMode value, e.g. SYMMETRIC or ASYMMETRIC_SIGNED)
        :param ema_decay: decay factor of the EMA used to track the activation range
        :param dequantize: if True, de-quantize back to float after quantizing (fake quantization)
        :param inplace: whether to operate in place (note: the calls below pass False explicitly)
"""
super(FakeLinearQuantization, self).__init__()
self.num_bits = num_bits
self.mode = mode
self.dequantize = dequantize
self.inplace = inplace
# We track activations ranges with exponential moving average, as proposed by Jacob et al., 2017
        # https://arxiv.org/abs/1712.05877 (the activation range is tracked with an EMA)
# We perform bias correction on the EMA, so we keep both unbiased and biased values and the iterations count
# For a simple discussion of this see here:
# https://www.coursera.org/lecture/deep-neural-network/bias-correction-in-exponentially-weighted-averages-XjuhD
        self.register_buffer('ema_decay', torch.tensor(ema_decay))  # register as a buffer: buffers hold non-parameter state and are saved in the model's state_dict
        self.register_buffer('tracked_min_biased', torch.zeros(1))  # biased (raw) EMA value
        self.register_buffer('tracked_min', torch.zeros(1))         # bias-corrected (unbiased) value
        self.register_buffer('tracked_max_biased', torch.zeros(1))  # biased (raw) EMA value
        self.register_buffer('tracked_max', torch.zeros(1))         # bias-corrected (unbiased) value
        self.register_buffer('iter_count', torch.zeros(1))          # iteration count
self.register_buffer('scale', torch.ones(1))
self.register_buffer('zero_point', torch.zeros(1))
def forward(self, input):
# We update the tracked stats only in training
#
# Due to the way DataParallel works, we perform all updates in-place so the "main" device retains
# its updates. (see https://pytorch.org/docs/stable/nn.html#dataparallel)
# However, as it is now, the in-place update of iter_count causes an error when doing
# back-prop with multiple GPUs, claiming a variable required for gradient calculation has been modified
# in-place. Not clear why, since it's not used in any calculations that keep a gradient.
# It works fine with a single GPU. TODO: Debug...
        if self.training:  # statistics are collected only during training
with torch.no_grad():
                current_min, current_max = get_tensor_min_max(input)  # input is the output of the activation function
self.iter_count += 1
                # The biased value is the plain weighted (EMA) value; the unbiased value is biased / (1 - decay ** step)
self.tracked_min_biased.data, self.tracked_min.data = update_ema(self.tracked_min_biased.data,
current_min, self.ema_decay,
self.iter_count)
self.tracked_max_biased.data, self.tracked_max.data = update_ema(self.tracked_max_biased.data,
current_max, self.ema_decay,
self.iter_count)
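                # NOTE (assumption): update_ema is defined elsewhere in Distiller; presumably it does
                # something along the lines of
                #   biased_ema   = biased_ema * decay + (1 - decay) * value
                #   unbiased_ema = biased_ema / (1 - decay ** step)   # bias correction
                # and returns (biased_ema, unbiased_ema)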
if self.mode == LinearQuantMode.SYMMETRIC:
max_abs = max(abs(self.tracked_min), abs(self.tracked_max))
actual_min, actual_max = -max_abs, max_abs
            if self.training:  # after the EMA update of the activation range, scale and zero_point must be recomputed
self.scale.data, self.zero_point.data = symmetric_linear_quantization_params(self.num_bits, max_abs)
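                # NOTE (assumption): symmetric_linear_quantization_params presumably returns roughly
                #   scale = (2 ** (num_bits - 1) - 1) / max_abs,  zero_point = 0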
else:
actual_min, actual_max = self.tracked_min, self.tracked_max
signed = self.mode == LinearQuantMode.ASYMMETRIC_SIGNED
            if self.training:  # after the EMA update of the activation range, scale and zero_point must be recomputed
self.scale.data, self.zero_point.data = asymmetric_linear_quantization_params(self.num_bits,
self.tracked_min,
self.tracked_max,
signed=signed)
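                # NOTE (assumption): asymmetric_linear_quantization_params presumably computes roughly
                #   scale = (2 ** num_bits - 1) / (tracked_max - tracked_min)
                #   zero_point = round(scale * tracked_min), offset by 2 ** (num_bits - 1) when signed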
input = clamp(input, actual_min.item(), actual_max.item(), False)
        # Quantize and then de-quantize; this step needs no extra gradient (straight-through)
input = LinearQuantizeSTE.apply(input, self.scale, self.zero_point, self.dequantize, False)
return input
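
The LinearQuantizeSTE autograd function used on the last line is not part of the snippet above. The sketch below is a minimal reconstruction, assuming the usual affine mapping q = round(scale * x - zero_point) and de-quantization x ≈ (q + zero_point) / scale; it is meant to illustrate the straight-through-estimator idea, not to reproduce Distiller's exact implementation (the inplace handling is omitted here).

import torch

class LinearQuantizeSTE(torch.autograd.Function):
    """Minimal STE fake-quantization sketch (an assumption, not Distiller's exact code)."""

    @staticmethod
    def forward(ctx, input, scale, zero_point, dequantize, inplace):
        # Affine quantization: project the float input onto the integer grid
        output = torch.round(scale * input - zero_point)
        if dequantize:
            # Map back to float so downstream layers see "fake-quantized" values
            output = (output + zero_point) / scale
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat round() as the identity in the backward pass,
        # so the incoming gradient passes through unchanged; the remaining arguments
        # (scale, zero_point, dequantize, inplace) receive no gradient
        return grad_output, None, None, None, None

A quick usage sketch of the module itself (hypothetical tensors): in train() mode every call updates the tracked range and recomputes scale/zero_point before fake-quantizing, while in eval() mode it only applies the frozen parameters.

fq = FakeLinearQuantization(num_bits=8, mode=LinearQuantMode.SYMMETRIC)
x = torch.randn(4, 16)

fq.train()
y_train = fq(x)   # updates tracked_min/max via EMA, recomputes scale/zero_point, fake-quantizes

fq.eval()
y_eval = fq(x)    # only clamps and fake-quantizes with the frozen scale/zero_point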