本节我们基于 PyTorch 来实测量化的实现。PyTorch 提供了 quantize_per_tensor_dynamic 函数来实现动态量化;下面我们自己写一个量化函数 quantize_tensor,并将其结果与 PyTorch 的结果对比,以此了解动态量化是如何实现的:
import torch
import numpy as np
import math
# Hand-written dynamic quantization, used to cross-check the output of
# torch.quantize_per_tensor_dynamic.
def quantize_tensor(array, num_bits=8):
    """Quantize a float tensor to signed ``num_bits``-bit integers (per-tensor, affine).

    The scale and zero point are derived from the tensor's own value range
    ("dynamic" quantization), so that ``(q_x - zero_point) * scale``
    approximates ``array``.

    Args:
        array: Non-empty floating-point ``torch.Tensor`` to quantize.
        num_bits: Width of the signed integer grid (8 -> [-128, 127]).

    Returns:
        A tuple ``(q_x, zero_point, scale)``:
        ``q_x`` is an int32 tensor shaped like ``array`` with values saturated
        to the signed ``num_bits`` range, ``zero_point`` is a 0-dim int32
        tensor, and ``scale`` is a 0-dim float tensor.

    Raises:
        RuntimeError: Propagated from ``array.max()`` if ``array`` is empty.
    """
    levels = 2. ** num_bits - 1.        # number of quantization steps (255 for 8 bits)
    qmin = -(2 ** (num_bits - 1))       # lowest representable integer
    qmax = 2 ** (num_bits - 1) - 1      # highest representable integer

    # Extend the observed range so it always contains 0.  Without this an
    # all-positive or all-negative tensor produces a zero point that lies
    # outside [qmin, qmax].  Inputs that already straddle 0 are unaffected.
    high = array.max().clamp(min=0)
    low = array.min().clamp(max=0)
    span = high - low

    if span == 0:
        # After the extension above, a zero span means every element is 0.
        # Dividing by it would yield NaN/inf, so fall back to an identity
        # mapping (scale 1, zero point 0), which dequantizes back to zeros.
        scale = torch.ones_like(span)
        zero_point = torch.zeros_like(span)
    else:
        scale = span / levels
        # Centre the value range on the middle of the integer grid; this
        # equals (qmin - low / scale) + 0.5, so the floor() below behaves
        # like rounding that quantity to the nearest integer.
        zero_point = -(high + low) / 2 / span * levels

    # Debug trace of the un-floored zero point (it appears in the demo output).
    print(zero_point)
    zero_point = zero_point.floor().int()

    q_x = array / scale + zero_point
    # round() resolves .5 ties to even and can land one step above qmax,
    # so saturate explicitly before the integer cast.
    q_x = q_x.round().clamp(qmin, qmax).int()
    return q_x, zero_point, scale
# Demo: quantize two random float32 operands with PyTorch's dynamic
# quantizer and with the hand-written quantize_tensor, then compare an
# integer matrix product (rescaled) against the float32 product.

# Random operands: a 1x10 row vector and a 10x10 matrix.
x1 = torch.randn(1, 10, dtype=torch.float32)
x2 = torch.randn(10, 10, dtype=torch.float32)

# Reference quantization done by PyTorch (signed 8-bit, full range).
xq1, xq2 = (
    torch.quantize_per_tensor_dynamic(t, dtype=torch.qint8, reduce_range=False)
    for t in (x1, x2)
)

print('************martrix value**************')
print(x1, x2, sep='\n')

# Scale chosen by PyTorch for each operand.
print('************scale**********************')
scale1, scale2 = xq1.q_scale(), xq2.q_scale()
print(scale1, scale2, sep='\n')

# Zero point chosen by PyTorch for each operand.
print('************zero point**********************')
zpoint1, zpoint2 = xq1.q_zero_point(), xq2.q_zero_point()
print(zpoint1, zpoint2, sep='\n')

# Same quantities from the hand-written routine, for side-by-side comparison.
# Each call is printed before the next one runs because quantize_tensor
# itself prints its intermediate zero point.
print('*************calc quant and zero point*********************')
q1, z1, s1 = quantize_tensor(x1)
print(q1, z1, s1)
q2, z2, s2 = quantize_tensor(x2)
print(q2, z2, s2)

# Raw integer payloads of the PyTorch-quantized tensors, widened to int32
# so the matrix product below has headroom.
xquant1 = xq1.int_repr().to(torch.int32)
xquant2 = xq2.int_repr().to(torch.int32)
print('************quant**********************')
print(xquant1, xquant2, sep='\n')

# Integer matmul on zero-point-corrected values; multiplying by both scales
# afterwards maps the result back to real values.
print('************mult result****************')
multZ = (xquant1 - zpoint1) @ (xquant2 - zpoint2)
print(multZ)
print(scale1, scale2, sep='\n')
print('quant result:')
print(multZ * scale1 * scale2)

# Float32 reference product.
realResult = x1 @ x2
print('real result')
print(realResult)
结果如下:
************martrix value**************
tensor([[ 1.2477, 0.0531, 0.7887, -1.9008, 0.0422, 0.0558, 2.1269, -0.5745,
-1.1107, -0.9602]])
tensor([[-0.0498, -1.9346, 1.1775, -0.2848, 1.9393, 0.1473, -0.6528, 1.4783,
-1.0426, -0.1134],
[-1.4242, -1.1538, -1.0923, 0.7910, -0.8136, 0.2567, 0.7243, 2.5828,
-0.5604, 0.1569],
[ 0.4030, 0.2074, 1.6686, -0.0956, 1.3616, -0.1492, 1.0531, -0.6623,
-1.1229, -1.9445],
[-0.3417, 0.4932, 1.1417, 0.0104, 0.2803, -0.1214, -0.2549, -1.2193,
0.8666, -0.9464],
[-0.6474, -0.9055, -1.0907, -0.8223, -1.3726, 0.2854, -0.3068, -0.7960,
0.3766, 0.9145],
[ 0.4355, 0.3613, -1.0598, 0.8375, -0.6023, 0.6905, -0.4290, 0.7039,
0.5284, 1.2257],
[ 1.5872, 0.2304, 0.8338, -1.7823, 2.5621, 0.4503, -0.2524, -0.5032,
1.1579, -0.4619],
[-0.3367, 0.9936, -0.9854, -0.9287, -0.2374, 2.7017, 0.3184, -0.1240,
0.8407, -1.0258],
[-0.3044, 0.3404, -3.9793, -0.0676, 1.1238, -0.1845, -1.1807, -1.3403,
0.3283, 1.3031],
[ 0.5594, -0.8091, 0.5098, 0.7334, -0.1245, 0.5204, -0.0674, -0.6535,
0.3256, -0.3021]])
************scale**********************
0.015794984967100852
0.026200167338053384
************zero point**********************
-8
24
*************calc quant and zero point*********************
tensor(-7.1556)
tensor([[ 71, -5, 42, -128, -5, -4, 127, -44, -78, -69]],
dtype=torch.int32) tensor(-8, dtype=torch.int32) tensor(0.0158)
tensor(24.3810)
tensor([[ 22, -50, 69, 13, 98, 30, -1, 80, -16, 20],
[ -30, -20, -18, 54, -7, 34, 52, 123, 3, 30],
[ 39, 32, 88, 20, 76, 18, 64, -1, -19, -50],
[ 11, 43, 68, 24, 35, 19, 14, -23, 57, -12],
[ -1, -11, -18, -7, -28, 35, 12, -6, 38, 59],
[ 41, 38, -16, 56, 1, 50, 8, 51, 44, 71],
[ 85, 33, 56, -44, 122, 41, 14, 5, 68, 6],
[ 11, 62, -14, -11, 15, 127, 36, 19, 56, -15],
[ 12, 37, -128, 21, 67, 17, -21, -27, 37, 74],
[ 45, -7, 43, 52, 19, 44, 21, -1, 36, 12]],
dtype=torch.int32) tensor(24, dtype=torch.int32) tensor(0.0262)
************quant**********************
tensor([[ 71, -5, 42, -128, -5, -4, 127, -44, -78, -69]],
dtype=torch.int32)
tensor([[ 22, -50, 69, 13, 98, 30, -1, 80, -16, 20],
[ -30, -20, -18, 54, -7, 34, 52, 123, 3, 30],
[ 39, 32, 88, 20, 76, 18, 64, -1, -19, -50],
[ 11, 43, 68, 24, 35, 19, 14, -23, 57, -12],
[ -1, -11, -18, -7, -28, 35, 12, -6, 38, 59],
[ 41, 38, -16, 56, 1, 50, 8, 51, 44, 71],
[ 85, 33, 56, -44, 122, 41, 14, 5, 68, 6],
[ 11, 62, -14, -11, 15, 127, 36, 19, 56, -15],
[ 12, 37, -128, 21, 67, 17, -21, -27, 37, 74],
[ 45, -7, 43, 52, 19, 44, 21, -1, 36, 12]],
dtype=torch.int32)
************mult result****************
tensor([[ 10245, -7079, 16232, -10362, 17634, -1202, 2760, 11839, -6065,
-3179]], dtype=torch.int32)
0.015794984967100852
0.026200167338053384
quant result:
tensor([[ 4.2397, -2.9295, 6.7173, -4.2881, 7.2975, -0.4974, 1.1422, 4.8993,
-2.5099, -1.3156]])
real result
tensor([[ 4.1968, -2.9492, 6.7218, -4.2829, 7.2829, -0.5282, 1.1583, 4.8998,
-2.5157, -1.3109]])
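通过我们给出的自定义函数 quantize_tensor,可以看出动态量化的原理:它的 zero point 和量化结果与 PyTorch 的输出一致。最后我们比较了量化矩阵相乘后的结果与浮点矩阵相乘的结果,可以看到大致能保证小数点后一位的精度。当然,精度取决于 scale 的大小:scale 越小,精度越高。在动态量化中 scale = (max − min) / 255,因此 scale 小就要求数据的取值范围小;如果 scale 是固定的,而数据超出了它所能表示的范围,就会出现溢出的情况。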