摘要:在执行PTQ或者QAT量化时确定FeatureMap的min,max值有很多种方式,其它方式这里不做赘述,主要介绍TI的edgevision中QAT相关实现代码使用的Percentile这种方式来确定FeatureMap的min,max,理解感悟作为记录。
工程代码中确定min,max的入口函数如下:
x_min, x_max = utils.extrema_fast(x, range_shrink_percentile=True, fast_mode=False)
对应实现代码如下:
def extrema_fast(src, range_shrink_percentile=0.0, channel_mean=False, sigma=0.0, fast_mode=True):
    """Convenience wrapper around extrema() with fast_mode enabled by default.

    With fast_mode=True the histogram (when range_shrink_percentile is used)
    is computed on a 2x-strided spatial subsample of 4D inputs, trading a
    little accuracy for ~4x less work.
    """
    # NOTE: the original pasted snippet had a stray markdown fence (```)
    # fused onto this line; it has been removed.
    return extrema(src, range_shrink_percentile, channel_mean, sigma, fast_mode)
extrema函数实现如下:
def extrema(src, range_shrink_percentile=0.0, channel_mean=False, sigma=0.0, fast_mode=False):
    """Return a (min, max) estimate for tensor `src`.

    Exactly one of four modes applies, checked in this priority order:
      1. exact min/max          -- when no option is set
      2. percentile shrink      -- when range_shrink_percentile is truthy
      3. channel-mean extrema   -- when channel_mean is True
      4. sigma clipping         -- when sigma is truthy (mean +/- sigma*std)
    """
    if range_shrink_percentile == 0 and sigma == 0 and channel_mean == False:
        # No clipping requested: plain exact extrema.
        mn = src.min()
        mx = src.max()
        return mn, mx
    elif range_shrink_percentile:
        # downsample for fast_mode
        hist_array, mn, mx, mult_factor, offset = tensor_histogram(src, fast_mode=fast_mode)
        if hist_array is None:
            # Degenerate histogram (tensor_histogram bailed out): the exact
            # extrema are already the best answer.
            return mn, mx
        # Find the outermost bin indices whose cumulative tail mass is still
        # below range_shrink_percentile (clips outliers at both ends).
        new_mn_scaled, new_mx_scaled = extrema_hist_search(hist_array, range_shrink_percentile)
        # Map bin indices back to value space — the inverse of the affine
        # mapping (x - offset) * mult_factor used inside tensor_histogram.
        new_mn = (new_mn_scaled / mult_factor) + offset
        new_mx = (new_mx_scaled / mult_factor) + offset
        # take care of floating point inaccuracies that can
        # increase the range (in rare cases) beyond the actual range.
        new_mn = max(mn, new_mn)
        new_mx = min(mx, new_mx)
        return new_mn, new_mx
    elif channel_mean:
        # For 4D (presumably NCHW — TODO confirm) inputs: reduce over batch
        # and spatial dims, then average the per-channel extrema.
        dim = [0,2,3] if src.dim() == 4 else None
        mn = torch.amin(src, dim=dim, keepdim=False).mean()
        mx = torch.amax(src, dim=dim, keepdim=False).mean()
        return mn, mx
    elif sigma:
        # Statistical range: keep mean +/- sigma standard deviations.
        mean = torch.mean(src)
        std = torch.std(src)
        mn = mean - sigma*std
        mx = mean + sigma*std
        return mn, mx
    else:
        assert False, 'unknown extrema computation mode'
其中的核心是tensor_histogram和extrema_hist_search的函数实现:
def tensor_histogram(src, fast_mode=False):
    """Build a normalized 256-bin histogram of `src` for percentile search.

    Returns (hist_array, mn, mx, mult_factor, offset): hist_array is a numpy
    array of per-bin frequencies in percent (summing to ~100), and a bin
    index b maps back to value space as b / mult_factor + offset.
    hist_array is None when the tensor is constant (min == max), in which
    case the exact extrema are returned with identity scale/offset.
    """
    # downsample for fast_mode: take a 2x-strided spatial subsample (with a
    # random phase) of large 4D inputs to cut the histogram cost by ~4x.
    fast_stride = 2
    fast_stride2 = fast_stride * 2
    if fast_mode and len(src.size()) == 4 and (src.size(2) > fast_stride2) and (src.size(3) > fast_stride2):
        r_start = random.randint(0, fast_stride - 1)
        c_start = random.randint(0, fast_stride - 1)
        src = src[..., r_start::fast_stride, c_start::fast_stride]
    #
    mn = src.min()
    mx = src.max()
    if mn == mx:
        # Constant tensor (generalizes the original all-zero check): the
        # range is 0 so mult_factor below would divide by zero. The exact
        # extrema are already correct; signal "no histogram" to the caller.
        return None, mn, mx, 1.0, 0.0
    #
    # compute range_shrink_percentile based min/max
    # frequency - bincount can only operate on unsigned
    num_bins = 255.0
    cum_freq = float(100.0)
    offset = mn
    range_val = torch.abs(mx - mn)
    mult_factor = (num_bins / range_val)
    # Affine-map every element into integer bins [0, 255]:
    # b = (x - mn) * num_bins / (mx - mn). Non-negative by construction,
    # which is what torch.bincount requires.
    tensor_int = (src.contiguous().view(-1) - offset) * mult_factor
    tensor_int = functional.round_g(tensor_int).int()
    # numpy version
    # hist = np.bincount(tensor_int.cpu().numpy())
    # hist_sum = np.sum(hist)
    # hist_array = hist.astype(np.float32) * cum_freq / float(hist_sum)
    # torch version
    hist = torch.bincount(tensor_int) # occurrence count of each bin index
    hist_sum = torch.sum(hist)
    # Normalize counts to percentages so they compare directly against
    # range_shrink_percentile in extrema_hist_search.
    hist = hist.float() * cum_freq / hist_sum.float()
    hist_array = hist.cpu().numpy()
    return hist_array, mn, mx, mult_factor, offset
# Sequential cumulative scan — not parallelizable; a numpy array indexes fastest.
def extrema_hist_search(hist_array, range_shrink_percentile):
    """Scan the normalized histogram inward from both ends.

    Returns the outermost (min, max) bin indices whose cumulative tail
    frequency is still strictly below `range_shrink_percentile`, i.e. the
    amount of mass that may be clipped away at each end.
    """
    bin_count = len(hist_array)

    # Left tail: advance the min index while the running mass stays under
    # the threshold.
    new_mn_scaled = 0
    running = 0.0
    for idx in range(bin_count):
        running += hist_array[idx]
        if running < range_shrink_percentile:
            new_mn_scaled = idx

    # Right tail: symmetric scan from the top bin downward.
    new_mx_scaled = bin_count - 1
    running = 0.0
    for idx in range(bin_count - 1, -1, -1):
        running += hist_array[idx]
        if running < range_shrink_percentile:
            new_mx_scaled = idx

    return new_mn_scaled, new_mx_scaled
代码实现就是这些,且看每一步都在做什么。
入口代码且不必说,主要看最后的核心实现函数。其中src参数就是FeatureMap入参,也就是Tensor矩阵,mn=src.min() 和 mx=src.max() 作为基础的min、max值。只要mn、mx不同时为0(即Tensor不是全零),就开始使用percentile这种方式来确定mn、mx参数:
假设输入src=torch.Tensor(np.random.randn(5, 5, 5)),通过如下方式将所有点映射到[0,255]。推测这里使用255(即256个档位的最大下标,对应8bit量化的表示范围)正是因为目标是8bit量化。
num_bins = 255.0
cum_freq = float(100.0)
offset = mn
range_val = torch.abs(mx - mn)
mult_factor = (num_bins / range_val)
tensor_int = (src.contiguous().view(-1) - offset) * mult_factor
tensor_int = functional.round_g(tensor_int).int()
而后取整并统计tensor_int在0-255上每个整数值出现的次数,这里可以看如下测试获得相应结果:
src=tensor([[[ 1.2441, -0.0867, 0.9002, -0.1932, 2.3541],
[-0.4824, -0.3223, 2.5375, -1.7840, 1.9771],
[ 0.3940, -1.4480, -0.3781, -0.7957, -1.0538],
[-1.2109, -0.3825, 0.9433, -1.1705, 0.3820],
[ 2.3790, 0.0455, 1.0485, -1.2459, 0.1798]],
[[-0.8588, 1.2191, 1.4201, 0.6451, 0.1503],
[-0.1744, 0.3837, -0.7574, -0.6267, -0.8850],
[ 1.2707, 0.0152, -0.8123, 1.0253, -1.2014],
[ 0.4145, -1.8473, -1.1998, -1.1097, -0.2954],
[ 1.3490, 0.3272, -0.7207, 0.8277, -1.4434]],
[[-0.4233, -0.5626, -0.5374, 0.5536, 0.7213],
[ 2.1416, 0.1454, 1.3634, 1.3695, 0.2687],
[ 0.0225, 0.6928, 0.5589, -0.1421, -1.0773],
[-0.9817, 1.0085, 1.0390, 0.6374, -0.4628],
[-0.4807, -0.8651, 0.9142, -0.1464, 0.4845]],
[[-0.5339, 1.2211, 0.5690, -0.3386, -0.8696],
[-1.6646, -0.4744, 2.2127, 0.1378, -0.3506],
[-0.4366, 0.9862, -1.0315, -0.2987, -0.7905],
[ 0.8619, 0.1890, 0.2300, 1.4928, -0.3998],
[ 0.3227, -1.3765, -0.7497, -0.0803, -1.1507]],
[[ 0.1519, -0.8967, 0.3365, -0.2855, -0.2702],
[ 1.1136, 0.3253, 2.6808, 0.4749, -1.1373],
[-1.4715, -1.5497, -0.5941, -0.3381, -0.1578],
[ 0.1751, 0.4058, 0.4405, 1.0508, -0.8351],
[ 1.3446, 1.4361, 1.0363, -0.2064, 1.6912]]])
hist=tensor([1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1,
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0,
0, 1, 0, 0, 0, 0, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1,
1, 0, 2, 0, 0, 3, 1, 1, 1, 0, 2, 1, 1, 2, 1, 2, 1, 1, 0, 0, 1, 1, 1, 1,
2, 0, 0, 1, 1, 0, 0, 0, 0, 2, 0, 1, 0, 0, 0, 0, 3, 1, 2, 1, 0, 1, 0, 1,
0, 0, 3, 1, 0, 0, 3, 2, 0, 1, 0, 2, 0, 0, 0, 1, 2, 0, 0, 0, 2, 0, 0, 1,
0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 2, 3, 0, 0, 0, 1,
0, 0, 0, 0, 0, 2, 1, 0, 1, 0, 0, 0, 2, 2, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1])
之后对hist再做一次归一化缩放(换算为百分比频率),就得到了hist_array矩阵,该矩阵会对mx、mn的更新起到决定性作用。
hist = torch.bincount(tensor_int) # calculate appearing number for each element in tensor
hist_sum = torch.sum(hist)
hist = hist.float() * cum_freq / hist_sum.float()
hist_array = hist.cpu().numpy()
extrema_hist_search中进行mx、mn的更新:从hist_array两端开始不断累积频率,返回累积频率仍小于range_shrink_percentile的最外侧index(也就是两端各允许裁掉不超过该百分比的数据),
通过tensor_histogram函数返回的mult_factor, offset计算新的new_mn、new_mx
new_mn = (new_mn_scaled / mult_factor) + offset
new_mx = (new_mx_scaled / mult_factor) + offset
与原来的mn和mx进行比较,从而更新mn和mx
new_mn = max(mn, new_mn)
new_mx = min(mx, new_mx)
return new_mn, new_mx
最后new_mn、new_mx的反计算方式,其实正是tensor_histogram中正向映射 tensor_int = (src - offset) * mult_factor 的逆运算:用bin下标除以mult_factor再加上offset,就把直方图的bin下标映射回了原始数值空间;再与原始mn、mx取max/min,只是为了防止浮点误差把范围撑大。