torch.quantize_per_tensor(input,scale, zero_point, dtype)实现8位量化:
摘要:对该函数各个参数的分析
量化:
计算机运算时,默认32位浮点数,若将32位浮点数,变成8位定点数,会快很多。
目前pytorch中的反向传播不支持量化,所以该量化只用于评估训练好的模型,或者将32位浮点数模型存储为8位定点数模型,读取8位定点数模型后需要转换为32位浮点数才能进行神经网络参数的训练。
量化函数原型:Q = torch.quantize_per_tensor(input,scale = 0.025 , zero_point = 0, dtype = torch.quint8)
**
- input为准备量化的32位浮点数,Q为量化后的8位定点数
- dtype为量化类型,quint8代表8位无符号数,qint8代表8位带符号数,最高位是符号位
- 假设量化为qint8,设量化后的数Q为0001_1101,最高位为0(符号位),所以是正数;后7位转换为10进制是29,所以Q代表的数为 :zero_point + Q * scale = 0 + 29 * 0.025 = 0.725
- 所以最终使用print显示Q时,显示的不是0001_1101而是0.725,但它在计算机中存储时,是0001_1101
- 使用dequantize()可以解除量化
- 量化公式为:
**
代码及其运行结果:
总结:
以zero_point为中心,用8位数Q代表input离中心有多远,scale为距离单位
即input ≈ zero_point + Q * scale.