# 个人 解析“Weighted-Entropy-based Quantization”

## 权重值量化

#!/usr/bin/env python3

import copy
import math as m
from functools import reduce

class QuantInfo:
    """Bin statistics over a sorted sequence split by boundary positions.

    Attributes:
        sum: indexable sequence read as an inclusive running total, i.e.
            ``sum[k]`` is used as the accumulated value of elements ``0..k``.
        slice: mutable list of exclusive end positions; bin ``i`` covers the
            positions ``slice[i - 1] .. slice[i] - 1`` (bin 0 starts at 0).
        level: total number of quantization levels.
        n: total number of elements, used as the normalizer.

    The parameter names ``sum`` and ``slice`` shadow builtins; they are kept
    because they are part of the constructor's signature.
    """

    def __init__(self, sum, slice, level, n):
        self.sum = sum
        self.slice = slice
        self.level = level
        self.n = n

    def _prefix(self, count):
        """Return the running total of the first ``count`` elements.

        ``count == 0`` yields 0 instead of reading ``sum[-1]``, which in
        Python would silently wrap around to the grand total.
        """
        return self.sum[count - 1] if count > 0 else 0

    def _bounds(self, n):
        """Return ``(start, stop)`` positions of bin ``n`` (stop exclusive)."""
        start = self.slice[n - 1] if n > 0 else 0
        return start, self.slice[n]

    def half_level(self):
        """Return ``(level + 1) // 2``, the number of bins that are tracked."""
        return (self.level + 1) // 2

    def p(self, n):
        """Return the fraction of all elements that fall into bin ``n``."""
        start, stop = self._bounds(n)
        return (stop - start) / self.n

    def wp(self, n):
        """Return the accumulated value of bin ``n`` divided by ``self.n``."""
        start, stop = self._bounds(n)
        return (self._prefix(stop) - self._prefix(start)) / self.n

    def w(self, n):
        """Return the mean accumulated value per element of bin ``n``.

        An empty bin yields 0.0 rather than a ZeroDivisionError.
        """
        start, stop = self._bounds(n)
        count = stop - start
        if count == 0:
            return 0.0
        return (self._prefix(stop) - self._prefix(start)) / count

    def _entropy(self, bins):
        """Sum ``-wp(i) * log(p(i) / 2)`` over ``bins``, skipping empty bins."""
        ent = 0
        for i in bins:
            prob = self.p(i)
            # An empty bin contributes nothing; evaluating log(0) would raise.
            if prob > 0:
                ent -= self.wp(i) * m.log(prob / 2)
        return ent

    def wEnt(self):
        """Return the weighted entropy over all tracked bins.

        Bin 0 is left out when ``level`` is odd (the range starts at
        ``level % 2``).
        """
        return self._entropy(range(self.level % 2, self.half_level()))

    def pEnt(self, n):
        """Return the weighted entropy of bins ``n`` and ``n + 1`` only.

        These are the two bins that share boundary ``slice[n]``; bin 0 is
        again left out when ``level`` is odd.  ``n + 1`` must be a valid
        index into ``slice``.
        """
        return self._entropy(range(max(n, self.level % 2), n + 2))
# end class

def _search_boundary(n, info, lo, hi):
    """Bisect ``[lo, hi]`` for the position of boundary ``n``.

    Each step compares ``info.pEnt(n)`` at the midpoint and at the position
    right after it and keeps the half on the side of the larger value
    (presumably pEnt is unimodal over the range -- TODO confirm).  When only
    two candidates remain, the one with the larger pEnt wins, ties going to
    the upper one.
    """
    if hi <= lo:
        # Fewer than two candidates: nothing to compare, so the current
        # boundary is kept.  Recursing on such a range never terminated.
        return

    while hi - lo > 1:
        mid = (lo + hi) // 2
        info.slice[n] = mid
        ent_at_mid = info.pEnt(n)
        info.slice[n] = mid + 1
        ent_after_mid = info.pEnt(n)

        if ent_at_mid > ent_after_mid:
            hi = mid
        else:
            lo = mid

    info.slice[n] = lo
    ent_lo = info.pEnt(n)
    info.slice[n] = hi
    ent_hi = info.pEnt(n)
    info.slice[n] = lo if ent_lo > ent_hi else hi


def recUpdate(
        n,
        info,
        begin=0,
        end=0):
    """Move boundary ``info.slice[n]`` to the position that maximizes pEnt.

    Args:
        n: index of the boundary to update; ``info.slice[n + 1]`` must exist.
        info: object exposing a mutable ``slice`` list and a ``pEnt(n)``
            method; ``slice[n]`` is modified in place.
        begin: lowest candidate position.
        end: highest candidate position.

    When both ``begin`` and ``end`` are 0 (the defaults), the range is
    derived from the neighbouring boundaries: it starts at ``slice[n - 1]``
    (0 for ``n == 0``) and stops at ``slice[n + 1] - 1``.  The search itself
    runs in a helper, so a derived range of exactly ``(0, 0)`` can no longer
    be mistaken for the "derive the range" request and loop forever.
    """
    if begin == 0 and end == 0:
        begin = 0 if n == 0 else info.slice[n - 1]
        end = info.slice[n + 1] - 1

    _search_boundary(n, info, begin, end)
# end recUpdate

def weight_quant_scale_forward(
        half_level,
        len_ch,
        num_ch,
        weight_list,
        quant_bdr,
        quant_rep):
    """Quantize every weight to a signed representative value.

    Args:
        half_level: number of magnitude bins that are examined.
        len_ch: number of weights per channel.
        num_ch: number of channels; weights are read from the flat
            ``weight_list`` at ``ch * len_ch + idx``.
        weight_list: flat sequence of weights.
        quant_bdr: lower magnitude boundary of each bin.
        quant_rep: representative magnitude of each bin.

    Returns:
        ``(out, bin_out)``: the quantized values and the signed bin indices,
        both in the order the weights were read.
    """
    out = []
    bin_out = []

    for ch in range(num_ch):
        base = ch * len_ch
        for k in range(len_ch):
            weight = weight_list[base + k]
            sign = 1 if weight >= 0 else -1
            magnitude = weight * sign

            # The bin that applies is the highest-indexed one whose boundary
            # the magnitude reaches; with no match, bin 0 and value 0 remain.
            reached = [j for j in range(half_level)
                       if magnitude >= quant_bdr[j]]
            if reached:
                chosen_bin = reached[-1]
                chosen_rep = quant_rep[chosen_bin]
            else:
                chosen_bin = 0
                chosen_rep = 0

            out.append(chosen_rep * sign)
            bin_out.append(chosen_bin * sign)

    return (out, bin_out)

#end weight_quant_scale_forward

def eh_weight_quant_update(
        num_level,
        num_ch,
        len_ch,
        in_weight,
        slice):
    """Fit quantization bins to ``in_weight`` and print the quantized result.

    Steps:
      1. Square every weight, sort the squares and build their inclusive
         running total.
      2. Starting from the boundaries in ``slice``, repeatedly call
         ``recUpdate`` on every inner boundary while ``QuantInfo.wEnt()``
         keeps increasing.
      3. Derive, per bin, a representative magnitude (square root of the
         bin's mean squared weight) and a lower boundary (square root of the
         squared weight at the bin's first sorted position).
      4. Quantize ``in_weight`` with ``weight_quant_scale_forward`` and print
         the representatives, boundaries, bin indices and quantized values.

    Args:
        num_level: total number of quantization levels.
        num_ch: number of channels, forwarded to the quantization pass.
        len_ch: number of weights per channel, forwarded likewise.
        in_weight: flat sequence of weights.
        slice: initial bin end positions; the first ``(num_level + 1) // 2``
            entries are read.  (The name shadows the builtin but is part of
            the signature.)

    Returns:
        None; the results are printed.
    """
    half_level = (num_level + 1) // 2
    n = len(in_weight)

    importance_sorted_list = sorted(weight ** 2 for weight in in_weight)

    # Inclusive running total in one pass.  QuantInfo reads its `sum`
    # argument as cumulative values, so this list (not the sorted squares
    # themselves) is what has to be handed over.
    prefix_sum_list = []
    running_total = 0
    for importance in importance_sorted_list:
        running_total += importance
        prefix_sum_list.append(running_total)

    # Private copy: the search below rewrites the boundaries in place.
    slice_vec = [slice[i] for i in range(half_level)]

    info = QuantInfo(
        prefix_sum_list,
        slice_vec,
        num_level,
        n)

    new_ent = info.wEnt()
    prev_ent = 0.

    # Sweep over all inner boundaries until the entropy stops improving.
    while new_ent > prev_ent:
        for i in range(half_level - 1):
            recUpdate(i, info)
        prev_ent = new_ent
        new_ent = info.wEnt()

    rep = [0] * half_level
    bdr = [0] * half_level

    # For an odd level count, bin 0 keeps the representative 0.
    for i in range(num_level % 2, half_level):
        rep[i] = m.sqrt(info.w(i))

    # slice_vec[i] is the first sorted position of bin i + 1; its squared
    # weight is read directly instead of differencing two running totals.
    for i in range(half_level - 1):
        bdr[i + 1] = m.sqrt(importance_sorted_list[slice_vec[i]])

    print("[scale factor]     rep is {}".format(rep))
    print("[log distribution] bdr is {}".format(bdr))

    # Quantize the weights that were passed in, not a module-level list.
    out, bin_out = weight_quant_scale_forward(
        half_level,
        len_ch,
        num_ch,
        in_weight,
        bdr,
        rep)

    print("[bin ] {}".format(bin_out))
    print("[real] {}".format(out))

#end eh_weight_quant_update

if __name__ == '__main__':
    # Demo: fit a 5-level quantizer to a small single-channel weight list.
    quant_level = 5
    number_ch = 1
    weight_list = [0.3, 1.4, 2.7, -3.1, 0.8, -1.1, 2.3, 0.0, 4.1]

    weight_count = len(weight_list)
    length_ch = weight_count // number_ch
    half_level = (quant_level + 1) // 2

    # Initial boundaries: split the weights into half_level equal-sized bins.
    initial_slice = [
        int(weight_count * rank / half_level)
        for rank in range(1, half_level + 1)
    ]

    print("quant level is {}".format(quant_level))
    print(weight_list)

    eh_weight_quant_update(
        quant_level,
        number_ch,
        length_ch,
        weight_list,
        initial_slice)

## 激活值量化

#!/usr/bin/env python3

import math as m
import copy
import sys
from functools import reduce

# Log-domain resolution: log2 of an activation is multiplied by this factor
# before being floored to an integer index, and indices are divided by it
# again when they are turned back into powers of two.
BASE_SIZE = 16
# Number of bins in the histogram built by eh_histogram.
BIN_SIZE  = 1024

def eh_histogram(
        active_list,
        offset):
    """Build a log2-domain histogram of the positive activations.

    Each value greater than 0 is mapped to
    ``floor(log2(value) * BASE_SIZE) - offset`` and counted in that bin;
    indices outside ``0 .. BIN_SIZE - 1`` are counted in the first or last
    bin respectively.  Values that are not greater than 0 are ignored.

    Args:
        active_list: sequence of activation values.
        offset: integer subtracted from every log-domain index.

    Returns:
        A list of ``BIN_SIZE`` counts.
    """
    counts = [0] * BIN_SIZE
    last_bin = BIN_SIZE - 1

    for activation in active_list:
        if not activation > 0:
            continue

        log_index = m.floor(m.log2(activation) * BASE_SIZE) - offset
        # Clamp out-of-range indices into the first / last bin.
        counts[min(max(log_index, 0), last_bin)] += 1

    return counts

# end eh_histogram

class LQInfo:
    """Level statistics over a cumulative log-domain histogram.

    Attributes:
        num_level: number of quantization levels; levels
            ``1 .. num_level - 1`` are evaluated.
        num_weight: total element count, used as the normalizer.
        offset: integer added to a log-domain index before it is converted
            back into a power of two.
        hist_scan: indexable sequence read as an inclusive running total of
            histogram counts.
    """

    def __init__(self, num_level, num_weight, offset, hist_scan):
        self.num_level = num_level
        self.num_weight = num_weight
        self.offset = offset
        self.hist_scan = hist_scan

    def _scanned(self, count):
        """Return the running total of the first ``count`` histogram bins.

        A non-positive ``count`` yields 0 instead of wrapping around to the
        end of ``hist_scan`` through a negative index.
        """
        return self.hist_scan[count - 1] if count > 0 else 0

    def p(self, idx, fst, step):
        """Return the fraction of elements assigned to level ``idx``.

        Level ``idx`` covers the histogram bins from
        ``fst + (step >> 1) * (2 * idx - 3)`` up to (excluding)
        ``fst + (step >> 1) * (2 * idx - 1)``; the last level
        (``num_level - 1``) instead extends to the end of the histogram.

        ``fst`` is the same quantity the sibling methods call ``fsr``; the
        parameter keeps its original name so existing callers are unaffected.
        The body now reads this parameter rather than a module-level ``fsr``.
        """
        half_step = step >> 1
        begin = fst + half_step * (2 * idx - 3)

        if idx == self.num_level - 1:
            count = self.num_weight - self._scanned(begin)
        else:
            end = fst + half_step * (2 * idx - 1)
            count = self._scanned(end) - self._scanned(begin)

        return count / self.num_weight

    def w(self, idx, fsr, step):
        """Return ``2 ** ((fsr + (idx - 1) * step + offset) / BASE_SIZE)``.

        The log-domain index is divided by BASE_SIZE, the same factor it was
        multiplied by when the histogram was built; dividing by the bin count
        BIN_SIZE would flatten every level's weight to roughly 1.
        """
        exponent = (fsr + (idx - 1) * step + self.offset) / BASE_SIZE
        return m.pow(2, exponent)

    def wEnt(self, fsr, step):
        """Return the weighted entropy of levels ``1 .. num_level - 1``.

        Each level contributes ``-w * p * log(p)``; levels with a
        non-positive probability are skipped.
        """
        ent = 0

        for i in range(1, self.num_level):
            prob = self.p(i, fsr, step)

            if prob > 0:
                ent -= self.w(i, fsr, step) * prob * m.log(prob)

        return ent

# end LQInfo

def WLQReLUForward2(
        num_level,
        active_list,
        offset,
        step,
        train):
    """Quantize activations onto a log2-domain grid.

    For every non-zero activation the magnitude is mapped to the level
    ``round((log2(|x|) * BASE_SIZE - offset) / step)``, capped at
    ``num_level - 1``:

    * a negative level produces the output ``sys.float_info.min`` when
      ``train == 0`` and ``0.0`` otherwise, and the level itself is recorded
      as the code;
    * any other level produces ``2 ** e`` with
      ``e = (offset + level * step) // BASE_SIZE``, and ``e`` is recorded as
      the code.

    Args:
        num_level: number of levels; the level index is capped at
            ``num_level - 1``.
        active_list: sequence of activation values.
        offset: log-domain index of level 0.
        step: log-domain distance between neighbouring levels.
        train: mode switch, compared against 0 for the negative-level case.

    Returns:
        ``(float_out, quant_out, sign_out)``.  ``float_out`` (signed
        quantized values) and ``sign_out`` (0 for non-negative input, 1 for
        negative) hold one entry per input.  ``quant_out`` holds one code per
        *non-zero* input: a zero input produces the output 0 and no code.
    """
    float_out = []
    quant_out = []
    sign_out = []

    for in_data in active_list:
        sign = -1 if in_data < 0 else 1
        magnitude = in_data * sign
        sign_out.append(0 if sign > 0 else 1)

        value = 0
        if magnitude > 0:
            # The tiny constant keeps the division defined when step == 0.
            level = round(
                (m.log2(magnitude) * BASE_SIZE - offset) / float(step + 10e-10))
            # Reserve 2bit for isZero and isPositive number
            level = min(num_level - 1, level)

            if level < 0:
                code = level
                value = sys.float_info.min if train == 0 else 0.
            else:
                code = (offset + level * step) // BASE_SIZE
                value = m.pow(2, code)

            # Append the final code once.  Writing it back by input index
            # broke as soon as an earlier zero input had left quant_out
            # shorter than the number of inputs processed.
            quant_out.append(code)

        float_out.append(sign * value)

    return (float_out, quant_out, sign_out)

if __name__ == '__main__':
    # Demo: calibrate offset/step for a small activation list, then quantize
    # it.  Original note: only the first picture goes through this step.
    num_level = 4
    offset = -640
    active_list = [0.3, 1.4, 2.7, -3.1, 0.8, -14.1, 2.3, 0.1, 14.1]
    print("active_list is {}".format(active_list))

    hist = eh_histogram(
        active_list,
        offset)

    # Inclusive running total of the histogram counts.
    suffix_sum_list = []
    running_total = 0
    for bin_count in hist:
        running_total += bin_count
        suffix_sum_list.append(running_total)

    # min_offset: last bin whose running total is still 0 (-1 if none).
    # max_offset: first bin at which the running total reaches its maximum.
    min_offset = -1
    max_offset = -1
    max_value = -1
    for position in range(BIN_SIZE):
        cumulative = suffix_sum_list[position]
        if cumulative == 0:
            min_offset = position
        if cumulative > max_value:
            max_offset, max_value = position, cumulative

    length = max_offset - min_offset + 1
    offset += min_offset

    info = LQInfo(
        num_level,
        suffix_sum_list[max_offset],
        offset,
        suffix_sum_list[min_offset:])

    # Exhaustive search for the (fsr, half_step) pair with the largest
    # weighted entropy.
    max_ent = 0
    max_half_step = 0
    max_fsr = 0
    for half_step in range(1, 17):
        end_temp = length - half_step * (2 * num_level - 5)
        for fsr in range(40, end_temp):
            ent = info.wEnt(fsr, 2 * half_step)
            if ent > max_ent:
                max_ent, max_half_step, max_fsr = ent, half_step, fsr

    # Layout of bdr: [0] holds the step, [num_level] the offset,
    # [1 .. num_level - 1] and [num_level + 1 .. 2 * num_level - 1] hold
    # powers of two derived from the chosen fsr and half step.
    bdr = [0] * 256
    base_index = max_fsr + offset
    for level in range(1, num_level):
        bdr[num_level + level] = m.pow(
            2, (base_index + 2 * (level - 1) * max_half_step) / BASE_SIZE)
        bdr[level] = m.pow(
            2, (base_index + (2 * level - 1) * max_half_step) / BASE_SIZE)

    bdr[num_level] = base_index
    bdr[0] = 2 * max_half_step

    float_out, quant_out, sign_out = WLQReLUForward2(
        num_level,
        active_list,
        bdr[num_level],
        bdr[0],
        1)  # mean test

    print("float output is {}".format(float_out))
    print("quant output is {}".format(quant_out))
    print("sign  output is {}".format(sign_out))
    print("offset is {}".format(bdr[num_level]))
    print("step is {}".format(bdr[0]))

