env:
- pytorch==1.7.1
- torchvision==0.8.2
- python==3.6
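Notes:
- Accuracy drops
- The procedure is fairly simple, but the model definition still has to be modified
- Layer fusion requires knowing the model's structure
- Model size shrinks to about 1/4 of the original
- Inference speed improves by 20+%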
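To make the accuracy and size notes concrete, here is a minimal pure-Python sketch of 8-bit affine quantization. The helper names (`quantize`, `dequantize`), the sample weights, and the hand-picked `scale` / `zero_point` are illustrative assumptions, not PyTorch APIs. It only shows why storing 1 byte per value instead of 4 gives roughly 1/4 the size, and where the rounding error comes from.

```python
import struct

def quantize(values, scale, zero_point):
    """Map floats to unsigned 8-bit integers: q = clamp(round(x / scale) + zero_point, 0, 255)."""
    return [min(255, max(0, round(v / scale) + zero_point)) for v in values]

def dequantize(q_values, scale, zero_point):
    """Map the integers back to floats: x ~= (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in q_values]

weights = [-1.0, -0.263, 0.0, 0.404, 1.0]  # hypothetical fp32 weights
scale = 0.01       # hand-picked step size; real observers derive it from min/max
zero_point = 128   # the integer that represents 0.0

q = quantize(weights, scale, zero_point)
restored = dequantize(q, scale, zero_point)

fp32_bytes = len(struct.pack(f"{len(weights)}f", *weights))  # 4 bytes per value
int8_bytes = len(struct.pack(f"{len(q)}B", *q))              # 1 byte per value
max_error = max(abs(a - b) for a, b in zip(weights, restored))

print(q, fp32_bytes, int8_bytes, max_error)
```

The rounding error is bounded by half a quantization step (`scale / 2`) for values inside the representable range, which is the source of the accuracy drop.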
step1: Load the model
Load it the usual way; nothing special here.
model = Resnet().to(device)
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint)
model.to(device).eval()
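step2: Quantize
Just follow the template below; nothing special.
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend) # backend depends on the platform: "fbgemm" for x86, "qnnpack" for ARM
listmix = [['conv','relu']] # supported patterns include conv+bn, conv+relu, conv+bn+relu
model = torch.quantization.fuse_modules(model,listmix) # fuse the listed layers; skip this line if you don't want fusion
model_fp32_prepared = torch.quantization.prepare(model)
# calibrate here: run representative inputs through model_fp32_prepared so the observers record activation ranges
model_int8 = torch.quantization.convert(model_fp32_prepared)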
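The template above can be sketched end to end, including the calibration pass between `prepare` and `convert`. `TinyNet` is a hypothetical stand-in for the real model (its attribute names `conv` and `relu` match the fuse list), and the random calibration tensors are placeholders for real images. The backend check is an assumption added so the sketch also runs on machines without fbgemm.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model."""
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> int8 at the input
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> float at the output

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

# Pick a backend this machine supports (fbgemm on x86, qnnpack on ARM).
engines = torch.backends.quantized.supported_engines
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = backend

model = TinyNet().eval()  # static quantization is done on a CPU model in eval mode
model.qconfig = torch.quantization.get_default_qconfig(backend)
model = torch.quantization.fuse_modules(model, [["conv", "relu"]])
model_fp32_prepared = torch.quantization.prepare(model)

# Calibration: feed representative inputs so the observers record activation ranges.
# Random tensors are only a placeholder; use real images in practice.
with torch.no_grad():
    for _ in range(8):
        model_fp32_prepared(torch.rand(1, 3, 16, 16))

model_int8 = torch.quantization.convert(model_fp32_prepared)
out = model_int8(torch.rand(1, 3, 16, 16))
print(type(model_int8.conv).__name__, out.shape, out.dtype)
```

Skipping the calibration loop still produces a converted model, but the activation scales fall back to defaults, which tends to hurt accuracy much more than necessary.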
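step3: Persist (save the model)
There are two options: save only the weights, or save the weights together with the model structure.
Saving weights + structure shortens model loading time.
# save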
traced_model = torch.jit.trace(model_int8, img)
torch.jit.save(traced_model, "traced_int8.pt")
# load
model = torch.jit.load("traced_int8.pt")
model(img)
Saving weights only
# save
torch.save(model_int8.state_dict(), "int_8_post.pt")
# load
''' define the model structure '''
model = YourNet().to(device)
# do not load "int_8_post.pt" here: its keys only match the converted int8 model, so it is loaded after convert
model.eval()
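''' paste in the quantization steps from step2 '''
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend) # backend depends on the platform
listmix = [['conv','relu']] # supported patterns include conv+bn, conv+relu, conv+bn+relu
model = torch.quantization.fuse_modules(model,listmix) # fuse the listed layers
model_fp32_prepared = torch.quantization.prepare(model)
# no calibration needed here: the saved scales and zero points are loaded from the checkpoint
model_int8 = torch.quantization.convert(model_fp32_prepared)
''' load the weights '''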
checkpoint = torch.load("int_8_post.pt", map_location=device)
model_int8.load_state_dict(checkpoint)
model_int8.eval()
model_int8(img)
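The weights-only route works because the same fuse / prepare / convert sequence rebuilds a module tree whose keys match the saved int8 state_dict. A minimal sketch of that round trip follows; `TinyNet` and the `quantize` helper are hypothetical, and the in-memory `saved_state` stands in for what `torch.save` / `torch.load` would write and read.

```python
import torch
import torch.nn as nn

engines = torch.backends.quantized.supported_engines
BACKEND = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = BACKEND

class TinyNet(nn.Module):
    """Hypothetical stand-in for YourNet."""
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

def quantize(model, calibration_batches=()):
    """Fuse, prepare, optionally calibrate, and convert a float model to int8."""
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig(BACKEND)
    model = torch.quantization.fuse_modules(model, [["conv", "relu"]])
    prepared = torch.quantization.prepare(model)
    with torch.no_grad():
        for batch in calibration_batches:
            prepared(batch)
    return torch.quantization.convert(prepared)

# First run: quantize with calibration and keep the int8 state_dict.
img = torch.rand(1, 3, 16, 16)
model_int8 = quantize(TinyNet(), [torch.rand(1, 3, 16, 16) for _ in range(4)])
saved_state = model_int8.state_dict()  # stand-in for the contents of "int_8_post.pt"

# Later run: rebuild an uncalibrated int8 skeleton, then overwrite it with the saved values.
reloaded = quantize(TinyNet())
reloaded.load_state_dict(saved_state)
reloaded.eval()

same = torch.equal(model_int8(img), reloaded(img))
print(same)
```

The skeleton's own weights and default scales are throwaway values; everything that matters comes from the loaded state_dict, so the fp32 checkpoint is not needed at load time.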
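step4: Quantize and dequantize the input
This step changes how the model handles its input and output: a quantized model needs quantized input, and the output has to be dequantized before ordinary float computation in Python.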
class YourNet(nn.Module):
    def __init__(self, cfg, img_size=(416, 416), verbose=False):
        ... ...
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        ... ...
    def forward(self,input):
        x = self.quant(input)
        x = self.layer(x)
        x = self.dequant(x)
        ... ...
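What the two stubs do after `convert` can be illustrated directly on a tensor. In this small sketch the `scale` and `zero_point` are chosen by hand as an assumption; in the converted model they come from calibration.

```python
import torch

x = torch.tensor([0.0, 0.3, 1.0, 2.5])

# Roughly what a converted QuantStub does: float tensor -> quantized tensor.
xq = torch.quantize_per_tensor(x, scale=0.25, zero_point=0, dtype=torch.quint8)
stored = xq.int_repr()    # raw uint8 values kept in memory: 0, 1, 4, 10

# Roughly what a converted DeQuantStub does: quantized tensor -> float tensor.
x_back = xq.dequantize()  # 0.0, 0.25, 1.0, 2.5 (0.3 was rounded to the nearest step)

print(stored, x_back)
```

Quantized layers in between only accept tensors like `xq`, while code after the model (post-processing, loss, NumPy conversion) expects plain floats like `x_back`, which is why both stubs are placed at the boundaries of `forward`.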
References:
https://pytorch.org/docs/stable/quantization.html