INT8 Post-Training Quantization Based on KL Divergence

We know that the smaller the relative entropy (KL divergence) between two distributions P and Q, the closer Q is to P, and the less information is lost when Q is used to approximate P. NVIDIA's INT8 quantization is based on exactly this principle; the figure below shows the pseudocode of NVIDIA's INT8 calibration algorithm.
[Figure: pseudocode of NVIDIA's INT8 entropy calibration algorithm]
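For two discrete distributions P and Q, the relative entropy is

$$D_{KL}(P \parallel Q) = \sum_{i} P(i)\,\log\frac{P(i)}{Q(i)}$$

The calibration sweeps over candidate clipping thresholds and keeps the one whose quantized-and-re-expanded distribution Q minimizes this quantity against the reference distribution P.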

Below is the code that selects the optimal clipping threshold by minimizing the relative entropy.

import numpy as np

def compute_kl_divergence(P, Q):
    # D_KL(P || Q) = sum_i P[i] * log(P[i] / Q[i]).
    # Bins where P is nonzero but Q is zero would make the divergence
    # infinite; here they add a fixed penalty of 1 instead.
    kl_sum = 0.0
    for i in range(len(P)):
        if P[i] != 0:
            if Q[i] == 0:
                kl_sum += 1
            else:
                kl_sum += P[i] * np.log(P[i] / Q[i])
    return kl_sum


def threshold_distribution(distribution, target_bin):
    # Sweep every candidate clipping threshold in [target_bin, len(distribution))
    # and return the one whose quantized histogram is closest to the reference
    # distribution in the KL-divergence sense.
    target_threshold = target_bin
    min_kl_divergence = float('inf')
    length = len(distribution)

    for threshold in range(target_bin, length):
        # reference distribution P: keep the first `threshold` bins and
        # fold the clipped tail into the last kept bin
        t_distribution = distribution[:threshold].copy()
        t_distribution[threshold - 1] += np.sum(distribution[threshold:])

        # down-sample the first `threshold` bins into target_bin coarse bins;
        # each coarse bin covers a fractional range of num_per_bin fine bins
        num_per_bin = threshold / target_bin

        quantize_distribution = np.zeros((target_bin,))

        for i in range(target_bin):
            start = i * num_per_bin
            end = start + num_per_bin

            # partial contribution of the fine bin cut by the left edge
            left_upper = int(np.ceil(start))
            if left_upper > start:
                left_scale = left_upper - start
                quantize_distribution[i] += left_scale * distribution[left_upper - 1]

            # partial contribution of the fine bin cut by the right edge
            right_lower = int(np.floor(end))
            if right_lower < end:
                right_scale = end - right_lower
                quantize_distribution[i] += right_scale * distribution[right_lower]

            # fine bins fully inside [start, end)
            for j in range(left_upper, right_lower):
                quantize_distribution[i] += distribution[j]

        # expand the coarse bins back to `threshold` bins to get Q,
        # spreading each coarse bin's mass evenly over its nonzero source bins
        expand_distribution = np.zeros_like(t_distribution)

        for i in range(target_bin):
            start = i * num_per_bin
            end = start + num_per_bin

            # count the nonzero source bins (edge bins count fractionally),
            # so the coarse bin's mass is spread only over those
            count = 0

            left_upper = int(np.ceil(start))
            left_scale = 0
            if left_upper > start:
                left_scale = left_upper - start
                if t_distribution[left_upper - 1] != 0:
                    count += left_scale

            right_lower = int(np.floor(end))
            right_scale = 0
            if right_lower < end:
                right_scale = end - right_lower
                if t_distribution[right_lower] != 0:
                    count += right_scale

            for j in range(left_upper, right_lower):
                if t_distribution[j] != 0:
                    count += 1

            # every source bin is empty: nothing to redistribute
            # (this also guards against division by zero)
            if count == 0:
                continue
            expand_value = quantize_distribution[i] / count

            # spread the mass back onto the nonzero source bins
            if left_upper > start:
                if t_distribution[left_upper - 1] != 0:
                    expand_distribution[left_upper - 1] += expand_value * left_scale
            if right_lower < end:
                if t_distribution[right_lower] != 0:
                    expand_distribution[right_lower] += expand_value * right_scale
            for j in range(left_upper, right_lower):
                if t_distribution[j] != 0:
                    expand_distribution[j] += expand_value

        kl_divergence = compute_kl_divergence(t_distribution, expand_distribution)

        if kl_divergence < min_kl_divergence:
            min_kl_divergence = kl_divergence
            target_threshold = threshold

    return target_threshold


if __name__ == '__main__':
    # demo: a linearly increasing 2048-bin histogram, quantized to 128 bins
    distribution = np.arange(2048, dtype=np.float64)
    distribution /= np.sum(distribution)
    target_threshold = threshold_distribution(distribution, 128)
    print(target_threshold)
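The function above returns only the index of the winning histogram bin. To quantize, that index still has to be converted into an FP32-to-INT8 scale: the center of the chosen bin is taken as the clipping threshold |T| and mapped to the INT8 limit 127. A minimal sketch of that last step (the helper name threshold_to_scale and the max_abs argument, the largest absolute activation value seen during calibration, are illustrative assumptions, not part of the code above):

def threshold_to_scale(target_threshold, max_abs, num_bins=2048):
    # assumes the calibration histogram covers [0, max_abs)
    # with num_bins equal-width bins
    bin_width = max_abs / num_bins
    # the center of the winning bin is the clipping threshold |T|
    threshold_value = (target_threshold + 0.5) * bin_width
    # TensorRT-style mapping: int8 = round(fp32 / scale), clamped to [-127, 127]
    return threshold_value / 127.0

For the demo above, threshold_to_scale(target_threshold, max_abs=6.0) would then yield the scale for an activation whose calibration maximum was 6.0.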
