随着深度学习的发展,模型变得越来越复杂,随之而来的模型参数也越来越多,对于需要训练的模型硬件要求也越来越高。模型压缩技术就是为了解决模型使用成本的问题。通过提高推理速度,降低模型参数量和运算量。
模型压缩的方法:
- 网络剪枝(Network Pruning)
- 量化(Quantization)
- 低秩分解(Low-rank factorization)
- 知识蒸馏(Knowledge distillation)
1. Network Pruning
- 研究的核心问题就是:如何有效地裁剪模型参数且最小化精度的损失。
- 网络剪枝可以分为 结构化剪枝(Structured pruning) 和 非结构化剪枝(Unstructured pruning) 两类。
1.1 概念
结构化剪枝(Unstructured pruning): 它裁剪的粒度为单个神经元。如果对kernel进行非结构化剪枝,则得到的kernel是稀疏的,即中间有很多元素为0的矩阵。除非下层的硬件和计算库对其有比较好的支持,pruning后版本很难获得实质的性能提升。稀疏矩阵无法利用现有成熟的BLAS库获得额外性能收益。
结构化剪枝(Structured pruning): 又可进一步细分:如可以是channel-wise的,也可以是filter-wise的,还可以是在shape-wise的。
论文及参考:闲话模型压缩之网络剪枝(Network Pruning)篇_ariesjzj的博客-CSDN博客_network pruning
1.2 使用
Pytorch的模型剪枝方法:
- 第一种,对特定网络模块的剪枝(Pruning Model)
- 第二种,多参数模块的剪枝(Pruning multiple parameters)
- 第三种,全局剪枝(GLobal pruning)
- 第四种,用户自定义剪枝(Custom pruning)
# 第一种: 对特定网络模块的剪枝(Pruning Model).
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 是经历卷积操作后的图片尺寸
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
module = model.conv1
print(list(module.named_parameters()))
print(list(module.named_buffers()))
# 第一个参数: module, 代表要进行剪枝的特定模块, 之前我们已经制定了module=model.conv1,
# 说明这里要对第一个卷积层执行剪枝.
# 第二个参数: name, 指定要对选中的模块中的哪些参数执行剪枝.
# 这里设定为name="weight", 意味着对连接网络中的weight剪枝, 而不对bias剪枝.
# 第三个参数: amount, 指定要对模型中多大比例的参数执行剪枝.
# amount是一个介于0.0-1.0的float数值, 或者一个正整数指定剪裁掉多少条连接边.
prune.random_unstructured(module, name="weight", amount=0.3)
print(list(module.named_parameters()))
print(list(module.named_buffers()))
# 模型经历剪枝操作后, 原始的权重矩阵weight参数不见了,
# 变成了weight_orig. 并且刚刚打印为空列表的module.named_buffers(),
# 此时拥有了一个weight_mask参数.
print(module.weight)
# 经过剪枝操作后的模型, 原始的参数存放在了weight_orig中,
# 对应的剪枝矩阵存放在weight_mask中, 而将weight_mask视作掩码张量,
# 再和weight_orig相乘的结果就存放