Compression
Model training typically uses float32, but deployment does not require that much precision. The weights can be cast to float16 before saving, which reduces the size of the saved weights by roughly half.
- Train and save the model weights
import torch
import timm

model = timm.create_model("mobilevit_xxs", pretrained=False, num_classes=8)
model.load_state_dict(torch.load("model_mobilevit_xxs.pth"))
- Convert the data type and save
params = torch.load("model_mobilevit_xxs.pth")
for key in params.keys():
    # only cast floating-point tensors; integer buffers (e.g. BatchNorm's
    # num_batches_tracked) should keep their original dtype
    if params[key].is_floating_point():
        params[key] = params[key].half()
torch.save(params, "model_mobilevit_xxs_half.pth")
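If the half-precision checkpoint is later used for CPU inference, it is usually safest to cast the weights back to float32 when loading, since many CPU operators have limited float16 support. A minimal sketch reusing the file names above:

model = timm.create_model("mobilevit_xxs", pretrained=False, num_classes=8)
params_half = torch.load("model_mobilevit_xxs_half.pth")
# cast float16 tensors back to float32 so the model runs on CPU as usual
params_fp32 = {k: v.float() if v.is_floating_point() else v for k, v in params_half.items()}
model.load_state_dict(params_fp32)
model.eval()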
Pruning
After the model has been trained, its weights can be pruned. Common approaches (both are shown in the code below):
- Randomly prune a given proportion of the weights
- Prune weights by magnitude
import torch
import timm
import torch.nn.utils.prune as prune

model = timm.create_model('mobilevit_xxs', pretrained=False, num_classes=8)
model.load_state_dict(torch.load('model_mobilevit_xxs.pth'))
# select the layer to prune
module = model.head.fc
# random_unstructured: randomly zero 30% of the weights
prune.random_unstructured(module, name="weight", amount=0.3)
# l1_unstructured: zero the 30% of weights with the smallest absolute value
prune.l1_unstructured(module, name="weight", amount=0.3)
# ln_structured: remove the 50% of rows (dim=0) with the smallest L2 norm
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
A few things to keep in mind when using weight pruning:
- Pruning does not reduce the size of the saved weights; it only increases their sparsity (torch.nn.utils.prune keeps a weight_orig tensor and a weight_mask buffer until prune.remove is called);
- Pruning by itself does not make inference faster: the zeroed weights still go through the dense computation, so the reduced amount of computation is only theoretical unless sparse kernels or structured pruning are used;
- The pruning ratio affects model accuracy and needs to be tested and validated.
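Once the pruning mask is final, the extra weight_orig / weight_mask re-parametrization can be folded back into the weights with prune.remove. A short sketch continuing from the code above (the sparsity check and the output file name are just illustrative):

# fraction of weights in the pruned layer that are now zero
sparsity = float(torch.sum(module.weight == 0)) / module.weight.nelement()
print(f"sparsity of head.fc.weight: {sparsity:.2%}")
# make the pruning permanent: folds weight_orig * weight_mask into .weight
# and removes the weight_orig / weight_mask re-parametrization
prune.remove(module, "weight")
torch.save(model.state_dict(), "model_mobilevit_xxs_pruned.pth")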
Quantization
Quantization turns 32-bit multiply-accumulate operations into 8-bit ones, so the model weights become smaller and the memory requirements drop.
1. Post Training Dynamic Quantization (Eager Mode)
import torch

# define a floating point model
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.fc1 = torch.nn.Linear(100, 40)
        self.fc2 = torch.nn.Linear(40, 400)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# create a model instance
model_fp32 = M()
torch.save(model_fp32.state_dict(), 'tmp_float32.pth')

# create a quantized model instance
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,            # the original model
    {torch.nn.Linear},     # a set of layers to dynamically quantize
    dtype=torch.qint8)     # the target dtype for quantized weights

# run the model; the input feature size must match fc1's in_features (100)
input_fp32 = torch.randn(4, 100)
res = model_int8(input_fp32)
torch.save(model_int8.state_dict(), 'tmp_int8.pth')
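As a quick sanity check (not part of the original recipe), the two checkpoints can be compared on disk, and printing the model shows that the Linear layers were replaced by their dynamically quantized counterparts:

import os
print("fp32 checkpoint:", os.path.getsize('tmp_float32.pth'), "bytes")
print("int8 checkpoint:", os.path.getsize('tmp_int8.pth'), "bytes")
# the Linear layers now appear as DynamicQuantizedLinear modules
print(model_int8)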
2. Post Training Static Quantization
import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 100, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance and save the float32 weights
model_fp32 = M()
torch.save(model_fp32.state_dict(), 'tmp_float32.pth')

# static quantization requires eval mode
model_fp32.eval()
# 'fbgemm' is the quantization backend for x86 CPUs ('qnnpack' for ARM)
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# fuse conv + relu into a single module before quantization
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
# insert observers that record activation statistics
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)
# calibrate with representative data (a single random batch here, for illustration)
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)
# convert the calibrated model to an int8 quantized model
model_int8 = torch.quantization.convert(model_fp32_prepared)
# run the quantized model; the computation now happens in int8
res = model_int8(input_fp32)
torch.save(model_int8.state_dict(), 'tmp_int8.pth')
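Note that the int8 state dict can only be loaded back into a model that already has the quantized structure, so the same fuse/prepare/convert steps must be repeated on a fresh instance before load_state_dict. A minimal sketch of reloading tmp_int8.pth under that assumption:

# rebuild the quantized model skeleton, then load the int8 weights into it
model_reload = M()
model_reload.eval()
model_reload.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_reload_fused = torch.quantization.fuse_modules(model_reload, [['conv', 'relu']])
model_reload_prepared = torch.quantization.prepare(model_reload_fused)
model_reload_int8 = torch.quantization.convert(model_reload_prepared)
model_reload_int8.load_state_dict(torch.load('tmp_int8.pth'))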