PyTorch 模型剪枝实例教程二、结构化剪枝

目前大部分最先进的(SOTA)深度学习技术虽然效果好,但由于其模型参数量和计算量过高,难以用于实际部署。而众所周知,生物神经网络使用高效的稀疏连接(生物大脑神经网络balabala啥的都是稀疏连接的),考虑到这一点,为了减少内存、容量和硬件消耗,同时又不牺牲模型预测的精度,在设备上部署轻量级模型,并通过私有的设备上计算以保证隐私,通过减少参数数量来压缩模型的最佳技术非常重要

稀疏神经网络在预测精度方面可以达到密集神经网络的水平,但由于模型参数量小,理论上来讲推理速度也会快很多。而模型剪枝是一种将密集神经网络训练成稀疏神经网络的方法。

本文将通过学习官方示例教程,介绍如何通过一个简单的实例教程来进行模型剪枝,实践深度学习模型压缩加速。

相关链接

深度学习模型压缩与加速技术(一):参数剪枝

PyTorch模型剪枝实例教程一、非结构化剪枝

PyTorch模型剪枝实例教程二、结构化剪枝

PyTorch模型剪枝实例教程三、多参数与全局剪枝

1.导包&定义一个简单的网络
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

'''搭建类LeNet网络'''
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 单通道图像输入,5×5核尺寸
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.conv2 = nn.Conv2d(3, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
2.获取网络需要剪枝的模块
model = LeNet().to(device=device)
module = model.conv1
print(list(module.named_parameters()))      # 6×5×5的weight + 6×1的bias 的参数量

输出:

[('weight', Parameter containing:
tensor([[[[ 0.1473,  0.1251,  0.0492, -0.1375, -0.0781],
          [ 0.0446, -0.1328,  0.0227,  0.0141, -0.1751],
          [ 0.0253,  0.0313,  0.0391,  0.1607, -0.0716],
          [-0.1125, -0.1641,  0.1691,  0.1583,  0.0449],
          [-0.0094, -0.1916,  0.1701,  0.0704,  0.0407]]],


        [[[-0.1945,  0.0709,  0.1071,  0.0038, -0.0686],
          [ 0.0187,  0.0710, -0.0955, -0.0778,  0.1927],
          [ 0.1643,  0.0791,  0.1235,  0.0241, -0.0021],
          [-0.1124,  0.0246, -0.0349, -0.1561,  0.0178],
          [-0.1779,  0.1216,  0.1086, -0.1837,  0.1789]]],


        [[[-0.0051, -0.1969, -0.0155,  0.1890,  0.1977],
          [-0.0654,  0.1219,  0.0849, -0.1937, -0.0933],
          [-0.0409,  0.1344,  0.1688,  0.1917, -0.1727],
          [ 0.1380, -0.1413, -0.1483, -0.0711, -0.0648],
          [-0.1571,  0.0570,  0.1783, -0.0786,  0.1367]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0346, -0.1446,  0.0633], requires_grad=True))]
3.模块结构化剪枝(核心)

剪枝一个模块,需要三步:

  • step1.在torch.nn.utils.prune中选定一个剪枝方案,或者自定义(通过子类BasePruningMethod)
  • step2.指定需要剪枝的模块和对应的名称
  • step3.输入对应函数需要的参数

这里,我们根据通道的L2范数,沿着张量的第0轴(第0轴对应卷积层的输出通道,conv1的维数为3×5×5)对weight参数进行结构化剪枝,使用ln_structured()方法。剪枝比例为33%,dim为0,基于L2范数(n=2)

prune.ln_structured(module, name="weight", amount=0.33, n=2, dim=0)
print(module.weight)

输出:

tensor([[[[ 0.1327,  0.0812, -0.0225, -0.0809,  0.1461],
          [ 0.1335, -0.1709,  0.0575, -0.1608, -0.0677],
          [-0.0397, -0.0982,  0.0654, -0.1030, -0.1656],
          [-0.0570,  0.1940,  0.0085,  0.1896,  0.1979],
          [-0.0673,  0.0910, -0.0177, -0.1748,  0.1667]]],

        [[[-0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000]]],

        [[[ 0.1668,  0.1742, -0.1581, -0.1208,  0.0745],
          [ 0.0459, -0.0275, -0.1190, -0.1631, -0.1956],
          [-0.0480, -0.1716, -0.0168,  0.0089,  0.0876],
          [-0.0129, -0.1616, -0.1164, -0.1869, -0.1782],
          [ 0.0411, -0.0278, -0.1266,  0.1329, -0.1240]]]],
       grad_fn=<MulBackward0>)

所有相关的张量,包括mask缓冲区和用于计算剪枝张量的原始参数都存储在模型的state_dict中,因此,如果需要,可以很容易地序列化和保存。

print(model.state_dict().keys())

输出:

odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

要使修剪永久存在,可以删除weight_trans(orig)和weight_mask重新参数化,并删除forward_pre_hook,可以使用torch.nn.util.prune中的remove函数。注意,这并没有取消修剪,就像它从未发生过一样。它只是简单地使它永久存在,相反,在它的剪枝版本中,通过将参数的权重重新分配给模型参数。

print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(model.state_dict().keys())

输出:

[('bias', Parameter containing:
tensor([0.1673, 0.0794, 0.0110], requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 1.2122e-01, -9.2122e-02,  6.4242e-02,  2.0936e-02,  7.7185e-03],
          [ 1.6201e-01, -1.2338e-01, -1.2014e-01,  3.0895e-03, -3.8402e-02],
          [ 7.5407e-03,  1.9274e-01, -3.0035e-02,  1.9638e-02, -5.5985e-03],
          [-1.2915e-01, -7.7561e-02,  6.8224e-02, -1.8743e-01, -1.6051e-01],
          [ 1.4066e-01,  1.1038e-01, -1.8010e-01,  9.4039e-02, -1.2981e-01]]],

        [[[-1.3836e-01, -1.8937e-01,  3.2540e-02, -6.2541e-02,  1.6695e-01],
          [ 1.3803e-01,  1.0196e-01,  8.2551e-02, -1.2815e-06, -1.4024e-02],
          [-3.7121e-02, -1.8625e-01,  4.1115e-02, -1.5329e-01,  3.8362e-02],
          [-5.7373e-02,  9.3459e-02,  5.9365e-02, -9.4975e-02,  1.7842e-01],
          [ 2.2319e-02, -5.2064e-02, -1.9440e-01, -1.7895e-03,  8.3709e-02]]],

        [[[ 1.4024e-01,  6.4016e-02,  1.6549e-01,  9.6163e-02,  1.8803e-01],
          [-5.8840e-02, -1.8487e-01,  1.8037e-01,  7.3717e-02,  1.9991e-01],
          [ 7.9629e-02, -1.1025e-01,  1.2504e-01,  4.6581e-02,  2.2388e-04],
          [-3.6367e-02,  9.8296e-02,  6.5209e-02,  1.7801e-01,  1.3420e-01],
          [ 1.4725e-01, -1.9269e-01,  1.9282e-02, -1.3924e-01, -6.2607e-02]]]],
       requires_grad=True))]
[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],

        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],

        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]]]))]
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

使用remove函数后

prune.remove(module, 'weight')
print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(model.state_dict().keys())

输出:

[('bias', Parameter containing:
tensor([ 0.1144, -0.1641,  0.0962], requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0142,  0.1698, -0.0730, -0.0358, -0.1309],
          [ 0.1520,  0.1900, -0.0843,  0.0950,  0.1674],
          [-0.1724,  0.1453, -0.1764,  0.0345, -0.1767],
          [ 0.0727,  0.1170,  0.1585, -0.0713, -0.0158],
          [ 0.1485, -0.0270, -0.0164,  0.0889,  0.1170]]],

        [[[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],

        [[[-0.1681,  0.1801, -0.0567, -0.0366,  0.0085],
          [ 0.0495,  0.0320, -0.0127, -0.1761, -0.0948],
          [ 0.1340,  0.1103,  0.1332, -0.1911, -0.1225],
          [ 0.0781, -0.0920, -0.1759,  0.0977,  0.0030],
          [-0.0436, -0.1694, -0.0094, -0.0553, -0.0591]]]], requires_grad=True))]
[]
odict_keys(['conv1.bias', 'conv1.weight', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

可以发现,缓冲区中保存的mask参数没了

4.总结

本示例首先搭建了一个类LeNet网络模型,为了进行结构化剪枝,我们选取了LeNet的conv1模块,该模块参数包含为3×5×5的weight卷积核参数和3×1的bias参数,通过示例,我们利用torch.nn.prune中的ln_structured剪枝方法,实现了对weight的3个通道中其中一个通道进行了L2 norm结构化剪枝。

本文用到的核心函数方法:

  • model.state_dict().keys(),模型所有的张量参数,还包括了mask缓冲区的参数
  • prune.ln_structured(),Lx Norm结构化剪枝
  • prune.remove(),从模块中移除修剪重参数化。已修剪的名为name的参数将保持永久修剪,而名为name+_trans (orig)的参数将从参数列表中删除。类似地,名为name+_mask的缓冲区将从缓冲区中删除。

参考:

Torch官方剪枝教程

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小风_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值