We know that the smaller the relative entropy (KL divergence) between two distributions P and Q, the closer Q is to P, and the less information is lost by approximating P with Q. NVIDIA's INT8 quantization is built on exactly this principle.
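In formulas, with P the reference (FP32) distribution and Q its approximation, the relative entropy is

    D_{\mathrm{KL}}(P \parallel Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}

The calibration pseudocode NVIDIA presented (the figure referenced here comes from the "8-bit Inference with TensorRT" GTC 2017 talk) goes roughly as follows; this is a paraphrase from memory of the slide, not a verbatim copy:

    Input: FP32 histogram with 2048 bins: bin[0], ..., bin[2047]
    for i in range(128, 2048):
        reference_distribution_P = [bin[0], ..., bin[i-1]]
        outliers_count = sum(bin[i], ..., bin[2047])
        reference_distribution_P[i-1] += outliers_count
        P /= sum(P)                                   # normalize
        candidate_distribution_Q = quantize [bin[0], ..., bin[i-1]] into 128 levels
        expand candidate_distribution_Q back to i bins
        Q /= sum(Q)                                   # normalize
        divergence[i] = KL_divergence(reference_distribution_P, candidate_distribution_Q)
    pick the index m with minimal divergence[m]
    threshold = (m + 0.5) * bin_width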
Below is the code that selects the best threshold based on relative entropy.
import numpy as np

def compute_kl_divergence(P, Q):
    """Discrete KL divergence D_KL(P || Q), with a fixed +1 penalty
    whenever P has mass in a bin where Q is empty (instead of +inf)."""
    kl_sum = 0.0
    for i in range(len(P)):
        if P[i] != 0:
            if Q[i] == 0:
                kl_sum += 1
            else:
                kl_sum += P[i] * np.log(P[i] / Q[i])
    return kl_sum
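# A quick sanity check of compute_kl_divergence (not in the original post):
# the divergence of a distribution with itself is 0, and it grows as Q
# drifts away from P, e.g.
#   compute_kl_divergence([0.5, 0.5], [0.5, 0.5])  -> 0.0
#   compute_kl_divergence([0.9, 0.1], [0.5, 0.5])  -> ~0.368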
def threshold_distribution(distribution, target_bin):
    """Search truncation thresholds in [target_bin, len(distribution)) and
    return the one whose target_bin-level re-quantization has minimal KL
    divergence from the truncated reference distribution."""
    target_threshold = target_bin
    min_kl_divergence = float('inf')
    length = len(distribution)
    for threshold in range(target_bin, length):
        # Reference distribution P: clip the histogram at `threshold` and
        # fold all outlier mass into the last kept bin.
        t_distribution = distribution[0:threshold].copy()
        t_distribution[threshold - 1] += np.sum(distribution[threshold:])

        # Quantize bins [0, threshold) down to target_bin coarse bins;
        # fractional bin boundaries are split proportionally.
        num_per_bin = threshold / target_bin
        quantize_distribution = np.zeros((target_bin,))
        for i in range(target_bin):
            start = i * num_per_bin
            end = start + num_per_bin
            left_upper = int(np.ceil(start))
            if left_upper > start:
                left_scale = left_upper - start
                quantize_distribution[i] += left_scale * distribution[left_upper - 1]
            right_lower = int(np.floor(end))
            if right_lower < end:
                right_scale = end - right_lower
                quantize_distribution[i] += right_scale * distribution[right_lower]
            for j in range(left_upper, right_lower):
                quantize_distribution[i] += distribution[j]

        # Candidate distribution Q: expand the coarse bins back to `threshold`
        # bins, spreading each coarse bin's mass evenly over the positions
        # that are non-zero in the reference distribution.
        expand_distribution = np.zeros_like(t_distribution)
        for i in range(target_bin):
            start = i * num_per_bin
            end = start + num_per_bin
            # Count how many (fractional) non-zero positions receive mass.
            count = 0
            left_upper = int(np.ceil(start))
            left_scale = 0
            if left_upper > start:
                left_scale = left_upper - start
                if t_distribution[left_upper - 1] != 0:
                    count += left_scale
            right_lower = int(np.floor(end))
            right_scale = 0
            if right_lower < end:
                right_scale = end - right_lower
                if t_distribution[right_lower] != 0:
                    count += right_scale
            for j in range(left_upper, right_lower):
                if t_distribution[j] != 0:
                    count += 1
            if count == 0:
                # Every position in this coarse bin is empty; nothing to spread.
                continue
            expand_value = quantize_distribution[i] / count
            if left_upper > start and t_distribution[left_upper - 1] != 0:
                expand_distribution[left_upper - 1] += expand_value * left_scale
            if right_lower < end and t_distribution[right_lower] != 0:
                expand_distribution[right_lower] += expand_value * right_scale
            for j in range(left_upper, right_lower):
                if t_distribution[j] != 0:
                    expand_distribution[j] += expand_value

        # Keep the threshold whose candidate Q is closest to the reference P.
        kl_divergence = compute_kl_divergence(t_distribution, expand_distribution)
        if kl_divergence < min_kl_divergence:
            min_kl_divergence = kl_divergence
            target_threshold = threshold
    return target_threshold
if __name__ == '__main__':
    # Toy histogram: a linear ramp over 2048 bins, normalized to sum to 1.
    distribution = np.arange(2048, dtype=np.float64)
    distribution /= np.sum(distribution)
    target_threshold = threshold_distribution(distribution, 128)
    print(target_threshold)
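Note that the returned value is a histogram bin index, not a scale factor yet. In TensorRT's scheme the index is first converted to a saturation threshold |T| in the original activation range and then to an INT8 scale. A minimal sketch, assuming the histogram covered [0, max_abs_value) with uniform bins (threshold_to_scale and max_abs_value are illustrative names, not from the original post):

def threshold_to_scale(target_threshold, max_abs_value, num_bins=2048):
    # Width of one histogram bin in the original activation units.
    bin_width = max_abs_value / num_bins
    # Center of the chosen bin is the saturation threshold |T|.
    T = (target_threshold + 0.5) * bin_width
    # Map the saturated range [-|T|, |T|] onto the signed INT8 range [-127, 127].
    return T / 127.0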