Determining Quantization Parameters: ShrinkRange


Abstract: There are many ways to determine a FeatureMap's min and max when performing PTQ or QAT quantization. The alternatives are not covered here; this post focuses on the Percentile method used by the QAT implementation in TI's edgevision to determine a FeatureMap's min and max, written up as a record of my understanding.

The entry function in the project code for determining min and max is:

x_min, x_max = utils.extrema_fast(x, range_shrink_percentile=True, fast_mode=False)

The corresponding implementation is shown below. (Note that the call above passes range_shrink_percentile=True; since this value is later compared numerically against histogram mass, Python treats True as 1, which effectively allows up to 1% of the values to be clipped at each tail.)

def extrema_fast(src, range_shrink_percentile=0.0, channel_mean=False, sigma=0.0, fast_mode=True):
    return extrema(src, range_shrink_percentile, channel_mean, sigma, fast_mode)

The extrema function is implemented as follows:

def extrema(src, range_shrink_percentile=0.0, channel_mean=False, sigma=0.0, fast_mode=False):
    if range_shrink_percentile == 0 and sigma == 0 and channel_mean == False:
        mn = src.min()
        mx = src.max()
        return mn, mx
    elif range_shrink_percentile:
        # downsample for fast_mode
        hist_array, mn, mx, mult_factor, offset = tensor_histogram(src, fast_mode=fast_mode)
        if hist_array is None:
            return mn, mx

        new_mn_scaled, new_mx_scaled = extrema_hist_search(hist_array, range_shrink_percentile)
        new_mn = (new_mn_scaled / mult_factor) + offset
        new_mx = (new_mx_scaled / mult_factor) + offset

        # take care of floating point inaccuracies that can
        # increase the range (in rare cases) beyond the actual range.
        new_mn = max(mn, new_mn)
        new_mx = min(mx, new_mx)
        return new_mn, new_mx
    elif channel_mean:
        dim = [0,2,3] if src.dim() == 4 else None
        mn = torch.amin(src, dim=dim, keepdim=False).mean()
        mx = torch.amax(src, dim=dim, keepdim=False).mean()
        return mn, mx
    elif sigma:
        mean = torch.mean(src)
        std = torch.std(src)
        mn = mean - sigma*std
        mx = mean + sigma*std
        return mn, mx
    else:
        assert False, 'unknown extrema computation mode'
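
For intuition: up to the 256-bin quantization of the histogram, the percentile branch approximates quantile-based tail clipping. Below is a rough stand-in using torch.quantile (my own sketch, not library code; note that range_shrink_percentile is expressed in percent, hence the division by 100):

    import torch

    x = torch.randn(1, 8, 32, 32)
    range_shrink_percentile = 0.01          # in percent, as in the code above
    p = range_shrink_percentile / 100.0     # convert to a fraction for torch.quantile
    x_flat = x.reshape(-1)
    x_min = torch.quantile(x_flat, p)       # clip ~0.01% from the low tail
    x_max = torch.quantile(x_flat, 1 - p)   # clip ~0.01% from the high tail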

The core lies in the tensor_histogram and extrema_hist_search functions:

def tensor_histogram(src, fast_mode=False):
    # downsample for fast_mode
    fast_stride = 2
    fast_stride2 = fast_stride * 2
    if fast_mode and len(src.size()) == 4 and (src.size(2) > fast_stride2) and (src.size(3) > fast_stride2):
        r_start = random.randint(0, fast_stride - 1)
        c_start = random.randint(0, fast_stride - 1)
        src = src[..., r_start::fast_stride, c_start::fast_stride]
    #
    mn = src.min()
    mx = src.max()
    if mn == 0 and mx == 0:
        return None, mn, mx, 1.0, 0.0
    #

    # compute range_shrink_percentile based min/max
    # frequency - bincount can only operate on unsigned
    num_bins = 255.0
    cum_freq = float(100.0)
    offset = mn
    range_val = torch.abs(mx - mn)
    mult_factor = (num_bins / range_val)
    tensor_int = (src.contiguous().view(-1) - offset) * mult_factor
    tensor_int = functional.round_g(tensor_int).int()

    # numpy version
    # hist = np.bincount(tensor_int.cpu().numpy())
    # hist_sum = np.sum(hist)
    # hist_array = hist.astype(np.float32) * cum_freq / float(hist_sum)

    # torch version
    hist = torch.bincount(tensor_int) # count how many elements fall into each integer bin
    hist_sum = torch.sum(hist)
    hist = hist.float() * cum_freq / hist_sum.float()
    hist_array = hist.cpu().numpy()
    return hist_array, mn, mx, mult_factor, offset
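
Before moving on, a side note on the fast_mode branch at the top of tensor_histogram: for 4D inputs with sufficiently large spatial dimensions it keeps every second row and every second column, starting from a random phase, so the histogram is built from roughly a quarter of the pixels. In isolation (assuming torch is imported):

    src = torch.randn(1, 3, 8, 8)
    r_start, c_start = 1, 0                 # chosen randomly in the real code
    sub = src[..., r_start::2, c_start::2]  # stride-2 subsampling of rows and cols
    print(sub.shape)                        # torch.Size([1, 3, 4, 4]), ~1/4 of the pixels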


# this code is not parallelizable. better to pass a numpy array
def extrema_hist_search(hist_array, range_shrink_percentile):
    new_mn_scaled = 0
    new_mx_scaled = len(hist_array) - 1
    hist_sum_left = 0.0
    hist_sum_right = 0.0
    for h_idx in range(len(hist_array)):
        r_idx = len(hist_array) - 1 - h_idx
        hist_sum_left += hist_array[h_idx]
        hist_sum_right += hist_array[r_idx]
        if hist_sum_left < range_shrink_percentile:
            new_mn_scaled = h_idx
        if hist_sum_right < range_shrink_percentile:
            new_mx_scaled = r_idx
        #
    #
    return new_mn_scaled, new_mx_scaled
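
To see the two-sided scan in action, here is a toy call (reusing extrema_hist_search as defined above, with numpy imported as np):

    import numpy as np

    # 8 bins summing to 100 (percent), with light mass at both tails
    hist_array = np.array([0.4, 0.4, 20.0, 58.0, 20.0, 0.4, 0.4, 0.4])
    print(extrema_hist_search(hist_array, range_shrink_percentile=1.0))
    # -> (1, 6): bins 0..1 hold 0.8% < 1% and bins 6..7 hold 0.8% < 1%,
    #    so each boundary moves inward by one bin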

So much for the code listings; now let's walk through what each step does.
The entry code needs no comment; the interesting parts are the two core functions. The src parameter is the incoming feature map, i.e., a tensor. mn = src.min() and mx = src.max() give the baseline min and max, and as long as mn and mx are not both zero, the Percentile method takes over to refine them:
Suppose the input is src = torch.Tensor(np.random.randn(5, 5, 5)). Every element is mapped onto [0, 255] as follows; presumably 255 is used because the target is 8-bit quantization.

    num_bins = 255.0
    cum_freq = float(100.0)
    offset = mn
    range_val = torch.abs(mx - mn)
    mult_factor = (num_bins / range_val)
    tensor_int = (src.contiguous().view(-1) - offset) * mult_factor
    tensor_int = functional.round_g(tensor_int).int()
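
To make the mapping concrete with the 5×5×5 example below: mn ≈ -1.8473 and mx ≈ 2.6808, so mult_factor = 255 / 4.5281 ≈ 56.32. The minimum lands exactly in bin 0 and the maximum in bin 255 (each value occurs once, matching the single counts at both ends of hist), and a value such as 0.0152 maps to round((0.0152 + 1.8473) × 56.32) = round(104.9) = bin 105.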

The mapped values are then rounded, and the number of occurrences of each integer in 0-255 is counted. The following test run shows the result:

src=tensor([[[ 1.2441, -0.0867,  0.9002, -0.1932,  2.3541],
         [-0.4824, -0.3223,  2.5375, -1.7840,  1.9771],
         [ 0.3940, -1.4480, -0.3781, -0.7957, -1.0538],
         [-1.2109, -0.3825,  0.9433, -1.1705,  0.3820],
         [ 2.3790,  0.0455,  1.0485, -1.2459,  0.1798]],

        [[-0.8588,  1.2191,  1.4201,  0.6451,  0.1503],
         [-0.1744,  0.3837, -0.7574, -0.6267, -0.8850],
         [ 1.2707,  0.0152, -0.8123,  1.0253, -1.2014],
         [ 0.4145, -1.8473, -1.1998, -1.1097, -0.2954],
         [ 1.3490,  0.3272, -0.7207,  0.8277, -1.4434]],

        [[-0.4233, -0.5626, -0.5374,  0.5536,  0.7213],
         [ 2.1416,  0.1454,  1.3634,  1.3695,  0.2687],
         [ 0.0225,  0.6928,  0.5589, -0.1421, -1.0773],
         [-0.9817,  1.0085,  1.0390,  0.6374, -0.4628],
         [-0.4807, -0.8651,  0.9142, -0.1464,  0.4845]],

        [[-0.5339,  1.2211,  0.5690, -0.3386, -0.8696],
         [-1.6646, -0.4744,  2.2127,  0.1378, -0.3506],
         [-0.4366,  0.9862, -1.0315, -0.2987, -0.7905],
         [ 0.8619,  0.1890,  0.2300,  1.4928, -0.3998],
         [ 0.3227, -1.3765, -0.7497, -0.0803, -1.1507]],

        [[ 0.1519, -0.8967,  0.3365, -0.2855, -0.2702],
         [ 1.1136,  0.3253,  2.6808,  0.4749, -1.1373],
         [-1.4715, -1.5497, -0.5941, -0.3381, -0.1578],
         [ 0.1751,  0.4058,  0.4405,  1.0508, -0.8351],
         [ 1.3446,  1.4361,  1.0363, -0.2064,  1.6912]]])
hist=tensor([1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0,
        0, 1, 0, 0, 0, 0, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1,
        1, 0, 2, 0, 0, 3, 1, 1, 1, 0, 2, 1, 1, 2, 1, 2, 1, 1, 0, 0, 1, 1, 1, 1,
        2, 0, 0, 1, 1, 0, 0, 0, 0, 2, 0, 1, 0, 0, 0, 0, 3, 1, 2, 1, 0, 1, 0, 1,
        0, 0, 3, 1, 0, 0, 3, 2, 0, 1, 0, 2, 0, 0, 0, 1, 2, 0, 0, 0, 2, 0, 0, 1,
        0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 2, 3, 0, 0, 0, 1,
        0, 0, 0, 0, 0, 2, 1, 0, 1, 0, 0, 0, 2, 2, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1])
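
These steps are easy to reproduce in isolation; here is a minimal sketch (torch.round stands in for the library's functional.round_g, which I take to be a rounding op with a training-friendly gradient):

    import torch

    src = torch.randn(5, 5, 5)
    mn, mx = src.min(), src.max()
    mult_factor = 255.0 / torch.abs(mx - mn)
    tensor_int = torch.round((src.contiguous().view(-1) - mn) * mult_factor).int()
    hist = torch.bincount(tensor_int)
    print(hist.sum().item())  # 125, one count per element
    print(len(hist))          # 256 bins, indices 0..255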

Rescaling hist once more yields the hist_array, which plays the decisive role in updating mx and mn. Since cum_freq = 100, each entry of hist_array is the percentage of all elements that fall into that bin (the array sums to 100), which is what makes the direct comparison against range_shrink_percentile in the next step meaningful.

    hist = torch.bincount(tensor_int) # count how many elements fall into each integer bin
    hist_sum = torch.sum(hist)
    hist = hist.float() * cum_freq / hist_sum.float()
    hist_array = hist.cpu().numpy()

extrema_hist_search then updates mx and mn: it accumulates the histogram from both ends simultaneously, and as long as the accumulated mass on a side stays below range_shrink_percentile, the corresponding boundary index keeps moving inward; the loop therefore returns the innermost bin indices whose cumulative tail mass is still under the threshold. The new min and max are then computed from these scaled indices using the mult_factor and offset returned by tensor_histogram:

    new_mn = (new_mn_scaled / mult_factor) + offset
    new_mx = (new_mx_scaled / mult_factor) + offset
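
Continuing the numeric example: if the search returned, say, new_mn_scaled = 2 and new_mx_scaled = 253 (hypothetical indices), then new_mn = 2 / 56.32 - 1.8473 ≈ -1.8118 and new_mx = 253 / 56.32 - 1.8473 ≈ 2.6453, a slightly narrower range than the raw [-1.8473, 2.6808].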

Finally, the results are clamped against the original mn and mx, so that floating-point inaccuracies cannot push the range beyond the true extrema:

    new_mn = max(mn, new_mn)
    new_mx = min(mx, new_mx)
    return new_mn, new_mx

The part that is a little unintuitive at first is the final back-mapping of new_mn and new_mx. Algebraically, (idx / mult_factor) + offset is the exact inverse of the forward mapping (x - offset) * mult_factor, but here it is applied to bin indices rather than to original tensor values, so it recovers a bin's representative value instead of an actual element of the tensor. In that sense it is an update rule rather than a lossless inverse of tensor_histogram, which is also why the final clamp above is needed.
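
To close, here is a condensed, self-contained sketch of the whole percentile branch (my own condensation of the code above, reusing extrema_hist_search and again substituting torch.round for functional.round_g; the all-zero early exit is omitted for brevity):

    import torch

    def shrink_range(src, range_shrink_percentile=0.01):
        # percentile-based min/max, condensed from the code walked through above
        mn, mx = src.min(), src.max()
        mult_factor = 255.0 / torch.abs(mx - mn)
        tensor_int = torch.round((src.view(-1) - mn) * mult_factor).int()
        hist = torch.bincount(tensor_int).float()
        hist_array = (hist * 100.0 / hist.sum()).cpu().numpy()  # bins as percentages
        new_mn_scaled, new_mx_scaled = extrema_hist_search(hist_array, range_shrink_percentile)
        new_mn = max(mn, new_mn_scaled / mult_factor + mn)  # offset == mn
        new_mx = min(mx, new_mx_scaled / mult_factor + mn)
        return new_mn, new_mx

    x = torch.randn(1, 8, 16, 16)
    print(x.min().item(), x.max().item())
    print(shrink_range(x, range_shrink_percentile=0.5))  # slightly narrower range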
