时间 | 版本 | 修改人 | 描述 |
---|---|---|---|
2024年5月14日10:44:30 | V0.1 | 宋全恒 | 新建文档 |
2024年5月14日16:28:16 | V1.0 | 宋全恒 | 填充了PyTorch对于两种量化方式的内容 |
简介
Pytorch动态量化

设计神经网络时,可以进行许多权衡。在模型开发和训练期间,您可以改变复发性神经网络中的层数和参数数量,并针对模型大小和/或模型延迟或吞吐量而权衡。
量化为您提供了一种方法,可以在训练完成后使用已知模型在性能和模型准确性之间进行类似的权衡。
量化
动态量化
定义
量化网络意味着将其转换为使用降低精度的整数表示来表示权重和/或激活。从浮点数转换为整数时,基本上是将浮点数乘以某个比例系数,然后将结果四舍五入为整数。
确定scale factor是各种量化方法的差异点。
动态量化的关键思想是,对于激活来说,我们将会根据运行时观察到的数据范围来确定scale factor。
这样可以确保 "调整 "比例因子,从而尽可能多地保留每个观测数据集的信号,而模型参数在模型转化过程中是已知的,他们提前转化并存储成INT8形式。
量化模型中的算术是使用矢量化 INT8 指令完成的。累加通常使用 INT16 或 INT32 完成,以避免溢出。如果下一层被量化或转换为 FP32 进行输出,则此更高精度值将缩小为 INT8。
动态量化相对不需要调整参数,这使得它非常适合作为将 LSTM 模型转换为部署的标准部分添加到生产管道中。
代码实践
# import the modules used here in this recipe
import torch
import torch.quantization
import torch.nn as nn
import copy
import os
import time
# define a very, very simple LSTM for demonstration purposes
# in this case, we are wrapping ``nn.LSTM``, one layer, no preprocessing or postprocessing
# inspired by
# `Sequence Models and Long Short-Term Memory Networks tutorial <https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html`_, by Robert Guthrie
# and `Dynamic Quanitzation tutorial <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`__.
class lstm_for_demonstration(nn.Module):
"""Elementary Long Short Term Memory style model which simply wraps ``nn.LSTM``
Not to be used for anything other than demonstration.
"""
def __init__(self,in_dim,out_dim,depth):
super(lstm_for_demonstration,self).__init__()
self.lstm = nn.LSTM(in_dim,out_dim,depth)
def forward(self,inputs,hidden):
out,hidden = self.lstm(inputs,hidden)
return out, hidden
torch.manual_seed(29592) # set the seed for reproducibility
#shape parameters
model_dimension=8
sequence_length=20
batch_size=1
lstm_depth=1
# random data for input
inputs = torch.randn(sequence_length,batch_size,model_dimension)
# hidden is actually is a tuple of the initial hidden state and the initial cell state
hidden = (torch.randn(lstm_depth,batch_size,model_dimension), torch.randn(lstm_depth,batch_size,model_dimension))
# here is our floating point instance
float_lstm <