Pytorch量化之静态量化

env:

  • pytorch==1.7.1
  • torchvision==0.8.2
  • python==3.6

注意:

  • 精度变差
  • 操作比较简单,但还是需要动模型
  • 层合并的部分需要对结构有了解
  • 模型大小变为原来的1/4
  • 推理速度提高20+%

step1:加载模型

就正常加载即可,没啥特别的

model = Resnet().to(device)
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint)
model.to(device).eval()

step2:量化

照猫画虎即可,没啥特别的

backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend)  # 不同平台不同配置

listmix = [['conv','relu']] # 可以是conv+bn conv+relu conv+bn+relu 
model = torch.quantization.fuse_modules(model,listmix) # 合并某些层,不想合并这句也可以跳过

model_fp32_prepared = torch.quantization.prepare(model)
model_int8 = torch.quantization.convert(model_fp32_prepared)

step3:持久化(保存模型)

两种,一种保存变量,一种保存变量+结构

保存变量+结构会节省加载模型的时间

# 保存
traced_model = torch.jit.trace(model_int8, img)
torch.jit.save(traced_model, "traced_int8.pt")

# 加载
model = torch.jit.load("traced_int8.pt")
model(img)

保存变量

# 保存
torch.save(model_int8.state_dict(), "int_8_post.pt")

# 加载
'''定义模型结构'''
model = YourNet().to(device)
checkpoint = torch.load("int_8_post.pt", map_location=device)
model.load_state_dict(checkpoint)
model.to(device).eval()

''' 把之前量化的操作粘贴进来'''
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend)  # 不同平台不同配置

listmix = [['conv','relu']] # 可以是conv+bn conv+relu conv+bn+relu 
model = torch.quantization.fuse_modules(model,listmix) # 合并某些层

model_fp32_prepared = torch.quantization.prepare(model)
model_int8 = torch.quantization.convert(model_fp32_prepared)
''' 加载变量'''
checkpoint = torch.load("int_8_post.pt", map_location=device)
model_int8.load_state_dict(checkpoint)
model_int8.eval()
model_int8(img)

step4:input压缩与解压缩

这步需要对模型输入修改一下,因为量化的模型需要量化的输入,python的计算需要解量化

class YourNet(nn.Module):
 
    def __init__(self, cfg, img_size=(416, 416), verbose=False):
        ... ...
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        ... ...
    def forward(self,input):
        x = self.quant(input)
        x = self.layer(x)
        x = self.dequant(x)
        ... ...

参考:

https://pytorch.org/docs/stable/quantization.html

https://github.com/pytorch/pytorch/issues/43016

https://github.com/pytorch/pytorch/issues/28331

  • 2
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值